scotch_guest_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse_macro_input, parse_quote, Expr, FnArg, ForeignItem, ItemFn, ItemForeignMod, Pat,
5    ReturnType, Signature, Stmt, Type, TypeReference,
6};
7
/// Reports whether `ty` is the name of a primitive type that is kept as-is
/// rather than being replaced by a pointer wrapper.
///
/// The check is an exact, case-sensitive match against a fixed list of
/// primitive names: `bool`, `char` and the 8- to 64-bit integers.
fn is_atom_type(ty: &str) -> bool {
    matches!(
        ty,
        "bool" | "char" | "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64"
    )
}
15
16#[derive(Clone, Copy)]
17enum WrapMode {
18    Encoded,
19    Managed,
20}
21
22impl WrapMode {
23    fn wrap(self, ty: Type) -> Type {
24        match self {
25            WrapMode::Encoded => parse_quote!(scotch_guest::EncodedPtr<#ty>),
26            WrapMode::Managed => parse_quote!(scotch_guest::MemoryType),
27        }
28    }
29}
30
31enum TypeTranslation {
32    Original,
33    Wrapped(Type),
34}
35
36fn translate_type(ty: Type, mode: WrapMode, allow_owned: bool) -> TypeTranslation {
37    match ty {
38        Type::Path(ref path)
39            if is_atom_type(&path.path.segments.last().unwrap().ident.to_string()) =>
40        {
41            TypeTranslation::Original
42        }
43        Type::Reference(TypeReference {
44            lifetime: None,
45            mutability: None,
46            elem,
47            ..
48        }) => TypeTranslation::Wrapped(mode.wrap(*elem)),
49        Type::Array(_) | Type::Tuple(_) => TypeTranslation::Wrapped(mode.wrap(ty)),
50        Type::Path(_) if allow_owned => TypeTranslation::Wrapped(mode.wrap(ty)),
51        _ => unimplemented!("Type is unsupported, consider using a reference instead."),
52    }
53}
54
55#[derive(Default)]
56struct HostInputTranslation {
57    call_args: Vec<Expr>,
58    prelude: Vec<Stmt>,
59    epilogue: Vec<Stmt>,
60}
61
62fn translate_host_inputs<'a>(it: impl Iterator<Item = &'a mut FnArg>) -> HostInputTranslation {
63    let mut out = HostInputTranslation::default();
64
65    it.map(|arg| {
66        if let FnArg::Typed(arg) = arg {
67            arg
68        } else {
69            panic!("self is not allowed in host functions")
70        }
71    })
72    .map(|arg| {
73        if let Pat::Ident(name) = arg.pat.as_mut() {
74            (name.ident.clone(), &mut arg.ty)
75        } else {
76            panic!("Invalid function argument name")
77        }
78    })
79    .for_each(|(name, ty)| {
80        if let TypeTranslation::Wrapped(new) =
81            translate_type(ty.as_ref().clone(), WrapMode::Managed, false)
82        {
83            *ty = Box::new(new);
84            out.prelude
85                .push(parse_quote!(let #name = scotch_guest::ManagedPtr::new(#name).unwrap();));
86            out.epilogue.push(parse_quote!(#name.free();));
87            out.call_args.push(parse_quote!(#name.offset()));
88        } else {
89            out.call_args.push(parse_quote!(#name));
90        }
91    });
92
93    out
94}
95
96fn translate_host_output(ret: &mut ReturnType) -> Stmt {
97    let mut out = parse_quote!(return out;);
98
99    if let ReturnType::Type(_, ty) = ret {
100        if let TypeTranslation::Wrapped(new) =
101            translate_type(ty.as_ref().clone(), WrapMode::Managed, true)
102        {
103            *ty = Box::new(new);
104            out = parse_quote! {return {
105                let ptr = scotch_guest::ManagedPtr::with_size_by_address(out);
106                let value = ptr.read().expect("Guest received invalid ptr");
107                ptr.free();
108                value
109            };};
110        }
111    }
112
113    out
114}
115
116/// Macro used to annotate `extern` blocks that contain plugin imports.
117/// ```ignore
118/// #[scotch_guest::host_functions]
119/// extern "C" {
120///     fn print(val: &String);
121/// }
122/// ```
123#[proc_macro_attribute]
124pub fn host_functions(_: TokenStream, input: TokenStream) -> TokenStream {
125    let host_funcs = parse_macro_input!(input as ItemForeignMod);
126    let funcs = host_funcs
127        .items
128        .into_iter()
129        .map(|item| {
130            // I know let_else exists but unfortunatelly it breaks the formatting.
131            if let ForeignItem::Fn(func) = item {
132                func
133            } else {
134                panic!("Only functions are allowed in host_functions block")
135            }
136        })
137        .map(|mut func| {
138            let Signature {
139                ident,
140                inputs,
141                output,
142                ..
143            } = func.sig.clone();
144
145            let sig = &mut func.sig;
146            let ending = translate_host_output(&mut sig.output);
147
148            let fake_id = format_ident!("_host_{}", sig.ident);
149            sig.ident = fake_id.clone();
150
151            let HostInputTranslation {
152                prelude,
153                epilogue,
154                call_args,
155            } = translate_host_inputs(sig.inputs.iter_mut());
156
157            quote! {
158                fn #ident(#inputs) #output {
159                    extern "C" {
160                        #[link_name = stringify!(#ident)]
161                        #sig;
162                    }
163
164                    unsafe {
165                        #(#prelude)*
166                        let out = #fake_id(#(#call_args),*);
167                        #(#epilogue)*
168
169                        #ending
170                    }
171                }
172            }
173        });
174
175    let out = quote! {
176        #(#funcs)*
177    };
178    out.into()
179}
180
181#[derive(Default)]
182struct GuestInputTranslation {
183    prelude: Vec<Stmt>,
184}
185
186fn translate_guest_inputs<'a>(it: impl Iterator<Item = &'a mut FnArg>) -> GuestInputTranslation {
187    let mut out = GuestInputTranslation::default();
188
189    it.map(|arg| {
190        let FnArg::Typed(arg) = arg else { panic!("self is not allowed in guest functions") };
191        let Pat::Ident(id) = &*arg.pat else { panic!("Invalid function declation") };
192        (id.ident.clone(), &mut arg.ty)
193    })
194    .for_each(|(name, ty)| {
195        if let TypeTranslation::Wrapped(new) = translate_type(ty.as_ref().clone(), WrapMode::Encoded, false) {
196            out.prelude
197                .push(parse_quote!(let #name: #ty = &unsafe { #name.read().expect("Guest was given invalid pointer") };));
198            *ty = Box::new(new);
199        };
200    });
201
202    out
203}
204
205fn translate_guest_output(ret: &mut ReturnType) -> Stmt {
206    let mut out = parse_quote!(return out;);
207
208    if let ReturnType::Type(_, ty) = ret {
209        if let TypeTranslation::Wrapped(new) =
210            translate_type(ty.as_ref().clone(), WrapMode::Managed, true)
211        {
212            *ty = Box::new(new);
213            out = parse_quote!(return scotch_guest::ManagedPtr::new(&out).unwrap().offset(););
214        }
215    }
216
217    out
218}
219
220/// Macro used to annotate guest functions that should be exposed to the host.
221/// ```ignore
222/// #[scotch_guest::guest_function]
223/// fn add_up_list(items: &Vec<i32>) -> i32 {
224///     items.iter().sum::<i32>()
225/// }
226/// ```
227#[proc_macro_attribute]
228pub fn guest_function(_: TokenStream, input: TokenStream) -> TokenStream {
229    let mut item_fn = parse_macro_input!(input as ItemFn);
230    item_fn.attrs.push(parse_quote!(#[no_mangle]));
231    item_fn.sig.abi = Some(parse_quote!(extern "C"));
232
233    let GuestInputTranslation { prelude } = translate_guest_inputs(item_fn.sig.inputs.iter_mut());
234    let output = item_fn.sig.output.clone();
235    let epilogue = translate_guest_output(&mut item_fn.sig.output);
236    let body = item_fn.block;
237
238    item_fn.block = parse_quote!({
239        #(#prelude)*
240        let out = (move || #output #body)();
241        #epilogue
242    });
243
244    let out = quote! {
245        #item_fn
246    };
247
248    out.into()
249}