1mod code_match;
2mod mangle;
3mod parse;
4mod util;
5mod tests;
6
7use parse::{Functions, map_to_cxx};
8use std::collections::HashMap;
9use proc_macro2::{TokenStream, Span};
10use proc_macro::{TokenStream as TS0, TokenTree};
11use std::env;
12use regex::Regex;
13use syn;
14use std::str::FromStr;
15use std::sync::Mutex;
16use crate::mangle::*;
17use crate::util::*;
18
19
/// Strategy tag for types declared with the `POD` wrapper: `show_dtor` emits no
/// destructor glue for them.
const TYPE_POD:i32 = 0;
/// Strategy tag chosen by `show_dtor` for the `SharedPtr`, `UniquePtr`, `Vec` and
/// empty wrappers; these types get destructor glue generated.
const TYPE_DTOR_TRIVIAL_MOVE:i32 = 1;
lazy_static::lazy_static! {
    // Registry shared by every macro invocation in this process, keyed by type name.
    // Each value is `(strategy << 16) | flags`, where `show_dtor` uses the flag bits
    // 1 (destructor declaration emitted), 2 (`ManDtor` impl emitted) and
    // 4 (`SharedPtr` release glue emitted) to avoid emitting the same item twice.
    static ref TYPE_STRATEGY: Mutex<HashMap<String, i32>> = Mutex::new(HashMap::new());
}
26
/// Accumulates the source text generated while expanding one `#[bridge]` attribute.
#[derive(Default)]
#[allow(dead_code)]
pub(crate) struct FFIBuilder{
    // Not read by any code visible in this file (the parser's own `is_cpp` is used
    // instead); presumably the reason for `allow(dead_code)` — TODO confirm.
    is_cpp: bool,
    // Foreign declarations; `build_bridge_code` wraps them in one `extern "C" { .. }` block.
    extc_code: String,
    // Safe wrapper functions and trait impls emitted after the `extern` block.
    norm_code: String,
    // Storage for the most recent error message so methods can return it as `&str`.
    err_str: String,
    // Set when generated code contains `asm!`; makes the output start with
    // `use std::arch::asm;`.
    asm_used: bool
}
36
37impl FFIBuilder {
38 pub fn new() -> Self { Self::default() }
39
40 fn dtor_code(tp: &str) -> String {
41 let tp = map_to_cxx(tp);
42 let dtor_name = dtor_name(tp);
43 let tp1 = tp.replace("<", "_").replace(">", "_");
44 format!("\t#[link_name = \"{dtor_name}\"]\n\tfn ffi__free_{tp1}(__o: *mut usize);\n")
45 }
46 fn sp_dtor_code(tp: &str) -> String {
47 let dtor_name = sp_dtor_name(tp);
48 format!("\t#[link_name = \"{dtor_name}\"]\n\tfn ffi__freeSP_{tp}(__o: *mut usize);\n")
49 }
50
51 fn show_dtor(self: &mut Self, tp: &str, rtwrap:&str, tp_cpp: &str)->Result<(), &str> {
52 let tp_strategy = match rtwrap {
53 "POD" => TYPE_POD,
54 "SharedPtr"|"UniquePtr"|"Vec"|"" => TYPE_DTOR_TRIVIAL_MOVE,
56 _ => {
57 self.err_str = format!("type {} not supported", rtwrap);
58 return Err(&self.err_str);
59 }
60 };
61 let mut mp = TYPE_STRATEGY.lock().unwrap();
62 let mut tp1 = if let Some(x) = mp.get(tp) {
63 if *x >> 16 != tp_strategy {
64 self.err_str = format!("type {tp} strategy conflict");
65 return Err(&self.err_str);
66 }
67 *x
68 } else {
69 tp_strategy << 16
70 };
71
72 if tp_strategy != TYPE_POD {
73 if (tp1 & 1 == 0) && rtwrap != "SharedPtr" {
74 tp1 |= 1;
75 self.extc_code += &Self::dtor_code(tp_cpp);
76 }
77 if rtwrap == "UniquePtr" && tp1 & 2 == 0 {
78 tp1 |= 2;
79 self.norm_code += &format!("
80impl ManDtor for {tp} {{
81 unsafe fn __dtor(ptr: *mut [u8;0]) {{
82 if ptr as usize != 0 {{
83 ffi__free_{tp}(ptr as *mut usize);
84 }}
85 }}
86}}");
87 }
88 if rtwrap == "SharedPtr" && tp1 & 4 == 0 {
89 tp1 |= 4;
90 self.extc_code += &Self::sp_dtor_code(tp);
91 self.norm_code += &format!("
92impl DropSP for {tp} {{
93 unsafe fn __drop_sp(ptr: *mut [u8;0]) {{
94 if ptr as usize != 0 {{
95 ffi__freeSP_{tp}(ptr as *mut usize);
96 }}
97 }}
98}}\n");
99 }
100 }
101 mp.insert(tp.to_string(), tp1);
102 Ok(())
103 }
104
105 fn get_link_name(self: &Self, func: &SimpFunc, is_cpp: bool)
106 -> Result<String, &'static str>
107 {
108 if ! is_cpp {
109 Ok(func.fn_name.to_string())
110 } else {
111 mangle(&func)
112 }
113 }
114
115 fn build_one_func(self:&mut Self, func: &SimpFunc, is_cpp: bool) -> Result<(), &str>{
116 let mut args_c = Vec::new();
117 let mut args_r = Vec::new();
118 let mut args_usage = Vec::new();
119 let mut fn_name = if let Some(pos) = func.fn_name.rfind("::") {
120 func.fn_name[pos + 2..].to_string()
121 }else{
122 func.fn_name.to_string()
123 };
124 if !func.klsname.is_empty() {
125 args_c.push("this__: *const u8".to_string());
126 args_r.push(format!("this__: CPtr<{}>", &func.klsname));
127 args_usage.push("this__.addr as *const u8".to_string());
128 fn_name = format!("{}__{}", &func.klsname, &func.fn_name);
129 }
130 let return_code_r = if func.is_async {
131 if func.ret.tp.is_empty() {
132 return Err("async function must have a return type");
133 }
134 if ! func.ret.tp_wrap.is_empty() {
135 self.err_str = format!("currently async function should only return non-generic types. wrap your type into a struct if need. unsupported return type {}", &func.ret.raw_str);
136 return Err(&self.err_str);
137 }
138 args_c.push("addr: usize".to_string());
139 args_usage.push("dyn_fv_addr".to_string());
140 format!(" -> {}", func.ret.tp_full)
141 } else {
142 match &func.ret.tp_wrap as &str {
143 "" if func.ret.tp.is_empty() => String::new(),
144 "POD" => format!(" -> {}", func.ret.tp),
145 _ => format!(" -> {}", func.ret.tp_full),
146 }
147 };
148
149 enum RetKind {
150 RtPrimitive,
151 RtCPtr,
152 RtSharedPtr,
153 RtObject,
154 }
155 let is_a64 = cfg!(target_arch="aarch64");
156 let mut ret_indirect = String::new();
157 let mut ret_kind = RetKind::RtPrimitive;
158 let return_code_c = match &func.ret.tp_wrap as &str {
159 "CPtr" => {
160 ret_kind = RetKind::RtCPtr;
161 " -> *const u8".to_string()
162 },
163 "SharedPtr"|"UniquePtr" => {
164 ret_kind = RetKind::RtSharedPtr;
165 if is_a64 {
166 self.asm_used = true;
167 ret_indirect = format!("let __rtox8 = &mut __rto as *mut {} as *mut u8;\n\t\t", &func.ret.tp_full);
168 } else {
169 args_c.push("__rto: * mut u8".to_string());
170 args_usage.push(format!("&mut __rto as *mut {} as *mut u8", &func.ret.tp_full));
171 }
172 "".to_string()
173 },
174 "" if func.is_async => String::new(),
175 "" if func.ret.tp.is_empty() => String::new(),
176 "" if func.ret.is_primitive => format!(" -> {}", func.ret.tp),
177 ""|"POD"|"Vec" => {
178 ret_kind = RetKind::RtObject;
179 if is_a64 {
180 self.asm_used = true;
181 ret_indirect = "let __rtox8 = &mut __rta as *mut usize;\n\t\t".to_string();
182 } else {
183 args_c.push("__rto: * mut usize".to_string());
184 args_usage.push("&mut __rta as *mut usize".to_string());
185 }
186 "".to_string()
187 }
188 _ => {
189 self.err_str = format!("return type {} not supported", &func.ret.raw_str);
190 return Err(&self.err_str);
191 }
192 };
193
194 for arg in &func.arg_list {
195 let mut args_x_done = false;
196 let is_ref = arg.tp_full.chars().next().unwrap() == '&';
197 match arg.tp_wrap.as_str() {
198 ""|"POD" if is_ref => {
199 match arg.tp.as_str() {
200 "CStr" => {
201 args_x_done = true;
202 args_c.push(format!("{}: *const i8", &arg.name));
203 args_r.push(format!("{}: &CStr", &arg.name));
204 args_usage.push(format!("{}.as_ptr()", &arg.name))
205 },
206 "str" => {
207 args_x_done = true;
208 args_c.push(format!("{}: *const u8, {}_len: usize", &arg.name, &arg.name));
209 args_r.push(format!("{}: &str", &arg.name));
210 args_usage.push(format!("{}.as_ptr()", &arg.name));
211 args_usage.push(format!("{}.len()", &arg.name));
212 },
213 "[u8]" => {
214 args_x_done = true;
215 args_c.push(format!("{}: *const u8, {}_len: usize", &arg.name, &arg.name));
216 args_r.push(format!("{}: &[u8]", &arg.name));
217 args_usage.push(format!("{}.as_ptr()", &arg.name));
218 args_usage.push(format!("{}.len()", &arg.name));
219 },
220 _ => args_usage.push(format!("{} as *{} {}", &arg.name, select_val(arg.is_const, "const", "mut"), &arg.tp)),
221 }
222 },
223 "CPtr" => args_usage.push(format!("{}.addr as * const u8", &arg.name)),
224 "Option" => match is_ref {
225 true => args_usage.push(format!("{}.as_ref().map_or(0 as * const {}, |x| x as * const {})", &arg.name, &arg.tp, &arg.tp)),
226 false => args_usage.push(format!("{}.map_or(0 as * const {}, |x| x as * const {})", &arg.name, &arg.tp, &arg.tp)),
227 },
228 "SharedPtr"|"UniquePtr" => args_usage.push(format!("{}.as_cptr().addr as * const {}", &arg.name, &arg.tp)),
230 _ if arg.is_primitive => args_usage.push(format!("{}", &arg.name)),
231 _ => {
232 let suggested_str = arg.raw_str.replace(":", ": &");
233 self.err_str = format!("function \"{}\" argument \"{}\" not supported, \
234 you should always use a reference for non-primitive types in interop functions.\n\
235 try use \"{}\" instead.", func.fn_name, &arg.raw_str, &suggested_str);
236 return Err(&self.err_str);
237 }
238 };
239 if !args_x_done {
240 args_c.push(format!("{}: {}", &arg.name, &arg.tp_asc));
241 args_r.push(format!("{}: {}", &arg.name, &arg.tp_full));
242 }
243 }
244
245 let link_name = if func.is_async {
246 let sa = SimpArg{
247 name: "dyn_fv_addr".to_string(),
248 tp: "usize".to_string(),
249 tp_full: "usize".to_string(),
250 tp_wrap: "".to_string(),
251 tp_cpp: format!("ValuePromise<{}>*", func.ret.tp_cpp),
252 is_const: false,
253 is_primitive: true,
254 raw_str: "usize".to_string(),
255 tp_asc: "usize".to_string()
256 };
257 let mut func1 = func.clone();
258 func1.ret.tp = "".to_string();
259 func1.ret.tp_cpp = "".to_string();
260 func1.ret.is_primitive = true;
261 func1.arg_list.insert(0, sa);
262 self.get_link_name(&func1, is_cpp)?
263 } else {
264 self.get_link_name(&func, is_cpp)?
265 };
266 let fnstart = format!("{} {}fn {}({}){}", &func.access,
267 if func.is_async { "async " } else { "" },
268 &fn_name, args_r.join(", "), return_code_r);
269 self.extc_code += &format!("\t#[link_name = \"{link_name}\"]\n\tfn ffi__{fn_name}({}){};\n",
270 args_c.join(", "), return_code_c);
271 match ret_kind {
272 RetKind::RtPrimitive => {},
273 _ => {
274 if let Err(s) = self.show_dtor(&func.ret.tp, &func.ret.tp_wrap, &func.ret.tp_cpp) {
275 self.err_str = s.to_string();
276 return Err(&self.err_str);
277 }
278 }
279 }
280 if ! ret_indirect.is_empty() {
281 if args_usage.len() > 0 {
284 let idx = args_usage.len() - 1;
285 let s0 = args_usage[idx].as_str();
286 let s0 = format!("{{let __argk={}; asm!(\"mov x8, {{xval1}}\", xval1=in(reg) __rtox8); __argk}}", s0);
287 args_usage[idx] = s0;
288 } else {
289 let s0 = "asm!(\"mov x8, {xval1}\", xval1=in(reg) __rtox8);\n\t\t";
290 ret_indirect += s0;
291 }
292 }
293 let usage = args_usage.join(", ");
294 let norm_code = match ret_kind {
295 RetKind::RtPrimitive if func.is_async => {
296 format!("let mut arr : [usize;2] = [0, 0];\n\
297 let mut fv= std::pin::pin!(FutureValue::<{}>::default());\n\
298 let dyn_fv_addr = fv.copy_vtbl(&mut arr);\n\
299 unsafe {{ ffi__{fn_name}({usage}); }}\n\
300 fv.await", &func.ret.tp)
301 },
302 RetKind::RtPrimitive => format!("unsafe {{ ffi__{fn_name}({usage}) }}"),
303 RetKind::RtCPtr => format!("CPtr{{ addr: unsafe {{ ffi__{fn_name}({usage}) as usize }}, _phantom: std::marker::PhantomData }}"),
304 RetKind::RtSharedPtr => {
305 let wrap1 = &func.ret.tp_wrap as &str;
306 let ret_type = &func.ret.tp as &str;
307 format!("let mut __rto = {wrap1}::<{ret_type}>::default();\n\
308 \tunsafe {{ {ret_indirect} ffi__{fn_name}({usage}); }}\n\
309 \t__rto")
310 },
311 RetKind::RtObject => {
312 let mut ret_type = &func.ret.tp_full as &str;
313 let tp1 = func.ret.tp_cpp.replace("<", "_").replace(">", "_");
314 let call_free = match func.ret.tp_wrap.as_str() {
315 "POD" => {
316 ret_type = &func.ret.tp as &str;
317 "".to_string()
318 }, _ => format!("ffi__free_{}(&mut __rta as *mut usize);\n\t\t", &tp1),
320 };
321 format!("const SZ:usize = (std::mem::size_of::<{ret_type}>()+16)/8;\n\
322 \tlet mut __rta : [usize;SZ] = [0;SZ];\n\
323 \tunsafe {{ {ret_indirect}\n\
324 \t\tffi__{fn_name}({usage}); \n\
325 \t\tlet __rto = (*(&__rta as *const usize as *const {ret_type})).clone();\n\
326 \t\t{call_free}__rto\n\
327 \t}}")
328 },
329 };
331 self.norm_code += &format!("{fnstart} {{\n\t{norm_code}\n}}\n");
332 Ok(())
333 }
334
335 pub fn build_bridge_code(self: &mut Self, input: TokenStream) -> Result<TokenStream, &str> {
336 let mut xxx = Functions::new();
337 if let Err(s) = xxx.parse_ts(input) {
338 self.err_str = s.to_string();
339 return Err(&self.err_str);
340 }
341
342 for func in &xxx.funcs {
343 if let Err(_) = self.build_one_func(func, xxx.is_cpp) {
344 return Err(&self.err_str);
345 }
346 }
347 let extc_code = move_obj(&mut self.extc_code);
348 let norm_code = move_obj(&mut self.norm_code);
349 let use_asm = select_val(self.asm_used, "use std::arch::asm;\n", "");
350 let all_code = format!("{use_asm}extern \"C\" {{\n{extc_code}}}\n{norm_code}\n");
351 if env_as_bool("RUST_BRIDGE_DEBUG") {
352 println!("{}", all_code);
353 }
354 TokenStream::from_str(&all_code).map_err(|e| {
355 self.err_str = e.to_string();
356 self.err_str.as_str()
357 })
358 }
359}
360
// Foreign helper invoked by `enable_msvc_debug` when a debug build is selected.
// Its definition is not part of this file; presumably it is linked in by the
// crate's build setup — TODO confirm.
extern "C" {
    fn enable_msvc_debug_c();
}
364
365#[proc_macro_attribute]
366pub fn bridge(_args: TS0, input: TS0) -> TS0 {
367 let mut bb = FFIBuilder::new();
368 match bb.build_bridge_code(input.into()) {
369 Ok(code) => code.into(),
370 Err(e) => syn::Error::new(Span::call_site(), e).to_compile_error().into()
371 }
372}
373
374#[proc_macro_attribute]
375pub fn enable_msvc_debug(args: TS0, _input: TS0) -> TS0
376{
377 let enable_ = if let Some(TokenTree::Ident(val)) = args.into_iter().next() {
378 val.to_string().parse::<i32>().unwrap_or(-1)
379 } else {
380 -1
381 };
382 let is_debug = match enable_ {
383 0 => false,
384 1 => true,
385 _ => {
386 match env::var("OUT_DIR") {
387 Ok(profile) => Regex::new(r"[\\/]target[\\/]debug[\\/]").unwrap().is_match(&profile),
388 Err(_) => false,
389 }
390 }
391 };
392 if is_debug {
393 unsafe{ enable_msvc_debug_c(); }
394 }
395 TS0::new()
396}