extism_pdk_derive/
lib.rs

1use proc_macro2::{Ident, Span};
2use quote::quote;
3use syn::{parse_macro_input, FnArg, GenericArgument, ItemFn, ItemForeignMod, PathArguments};
4
5/// `plugin_fn` is used to define an Extism callable function to export
6///
7/// It should be added to a function you would like to export, the function should
8/// accept a parameter that implements `extism_pdk::FromBytes` and return a
9/// `extism_pdk::FnResult` that contains a value that implements
10/// `extism_pdk::ToBytes`. This maps input and output parameters to Extism input
11/// and output instead of using function arguments directly.
12///
13/// ## Example
14///
15/// ```rust
16/// use extism_pdk::{FnResult, plugin_fn};
17/// #[plugin_fn]
18/// pub fn greet(name: String) -> FnResult<String> {
19///   let s = format!("Hello, {name}");
20///   Ok(s)
21/// }
22/// ```
23#[proc_macro_attribute]
24pub fn plugin_fn(
25    _attr: proc_macro::TokenStream,
26    item: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28    let mut function = parse_macro_input!(item as ItemFn);
29
30    if !matches!(function.vis, syn::Visibility::Public(..)) {
31        panic!("extism_pdk::plugin_fn expects a public function");
32    }
33
34    let name = &function.sig.ident;
35    let constness = &function.sig.constness;
36    let unsafety = &function.sig.unsafety;
37    let generics = &function.sig.generics;
38    let inputs = &mut function.sig.inputs;
39    let output = &mut function.sig.output;
40    let block = &function.block;
41
42    let no_args = inputs.is_empty();
43
44    if name == "main" {
45        panic!(
46            "extism_pdk::plugin_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
47        )
48    }
49
50    match output {
51        syn::ReturnType::Default => panic!(
52            "extism_pdk::plugin_fn expects a return value, `()` may be used if no output is needed"
53        ),
54        syn::ReturnType::Type(_, t) => {
55            if let syn::Type::Path(p) = t.as_ref() {
56                if let Some(t) = p.path.segments.last() {
57                    if t.ident != "FnResult" {
58                        panic!("extism_pdk::plugin_fn expects a function that returns extism_pdk::FnResult");
59                    }
60                } else {
61                    panic!("extism_pdk::plugin_fn expects a function that returns extism_pdk::FnResult");
62                }
63            }
64        }
65    }
66
67    if no_args {
68        quote! {
69            #[no_mangle]
70            pub #constness #unsafety extern "C" fn #name() -> i32 {
71                #constness #unsafety fn inner #generics() #output {
72                    #block
73                }
74
75                let output = match inner() {
76                    core::result::Result::Ok(x) => x,
77                    core::result::Result::Err(rc) => {
78                        let err = format!("{:?}", rc.0);
79                        let mut mem = extism_pdk::Memory::from_bytes(&err).unwrap();
80                        unsafe {
81                            extism_pdk::extism::error_set(mem.offset());
82                        }
83                        return rc.1;
84                    }
85                };
86                extism_pdk::unwrap!(extism_pdk::output(&output));
87                0
88            }
89        }
90        .into()
91    } else {
92        quote! {
93            #[no_mangle]
94            pub #constness #unsafety extern "C" fn #name() -> i32 {
95                #constness #unsafety fn inner #generics(#inputs) #output {
96                    #block
97                }
98
99                let input = extism_pdk::unwrap!(extism_pdk::input());
100                let output = match inner(input) {
101                    core::result::Result::Ok(x) => x,
102                    core::result::Result::Err(rc) => {
103                        let err = format!("{:?}", rc.0);
104                        let mut mem = extism_pdk::Memory::from_bytes(&err).unwrap();
105                        unsafe {
106                            extism_pdk::extism::error_set(mem.offset());
107                        }
108                        return rc.1;
109                    }
110                };
111                extism_pdk::unwrap!(extism_pdk::output(&output));
112                0
113            }
114        }
115        .into()
116    }
117}
118
119/// `shared_fn` is used to define a function that will be exported by a plugin but is not directly
120/// callable by an Extism runtime. These functions can be used for runtime linking and mocking host
121/// functions for tests. If direct access to Wasm native parameters is needed, then a bare
122/// `extern "C" fn` should be used instead.
123///
124/// All arguments should implement `extism_pdk::ToBytes` and the return value should implement
125/// `extism_pdk::FromBytes`, if `()` or `SharedFnResult<()>` then no value will be returned.
126/// ## Example
127///
128/// ```rust
129/// use extism_pdk::{SharedFnResult, shared_fn};
130/// #[shared_fn]
131/// pub fn greet2(greeting: String, name: String) -> SharedFnResult<String> {
132///   let s = format!("{greeting}, {name}");
133///   Ok(name)
134/// }
135/// ```
136#[proc_macro_attribute]
137pub fn shared_fn(
138    _attr: proc_macro::TokenStream,
139    item: proc_macro::TokenStream,
140) -> proc_macro::TokenStream {
141    let mut function = parse_macro_input!(item as ItemFn);
142
143    if !matches!(function.vis, syn::Visibility::Public(..)) {
144        panic!("extism_pdk::shared_fn expects a public function");
145    }
146
147    let name = &function.sig.ident;
148    let constness = &function.sig.constness;
149    let unsafety = &function.sig.unsafety;
150    let generics = &function.sig.generics;
151    let inputs = &mut function.sig.inputs;
152    let output = &mut function.sig.output;
153    let block = &function.block;
154
155    let (raw_inputs, raw_args): (Vec<_>, Vec<_>) = inputs
156        .iter()
157        .enumerate()
158        .map(|(i, x)| {
159            let t = match x {
160                FnArg::Receiver(_) => {
161                    panic!("Receiver argument (self) cannot be used in extism_pdk::shared_fn")
162                }
163                FnArg::Typed(t) => &t.ty,
164            };
165            let arg = Ident::new(&format!("arg{i}"), Span::call_site());
166            (
167                quote! { #arg: extism_pdk::MemoryPointer<#t> },
168                quote! { #arg.get()? },
169            )
170        })
171        .unzip();
172
173    if name == "main" {
174        panic!(
175            "export_pdk::shared_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
176        )
177    }
178
179    let (no_result, raw_output) = match output {
180        syn::ReturnType::Default => (true, quote! {}),
181        syn::ReturnType::Type(_, t) => {
182            let mut is_unit = false;
183            if let syn::Type::Path(p) = t.as_ref() {
184                if let Some(t) = p.path.segments.last() {
185                    if t.ident != "SharedFnResult" {
186                        panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
187                    }
188                    match &t.arguments {
189                        PathArguments::AngleBracketed(args) => {
190                            if args.args.len() == 1 {
191                                match &args.args[0] {
192                                    GenericArgument::Type(syn::Type::Tuple(t)) => {
193                                        if t.elems.is_empty() {
194                                            is_unit = true;
195                                        }
196                                    }
197                                    _ => (),
198                                }
199                            }
200                        }
201                        _ => (),
202                    }
203                } else {
204                    panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
205                }
206            };
207            if is_unit {
208                (true, quote! {})
209            } else {
210                (false, quote! {-> u64 })
211            }
212        }
213    };
214
215    if no_result {
216        quote! {
217            #[no_mangle]
218            pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) {
219                #constness #unsafety fn inner #generics(#inputs) -> extism_pdk::SharedFnResult<()> {
220                    #block
221                }
222
223
224                let r = || inner(#(#raw_args,)*);
225                if let Err(rc) = r() {
226                    panic!("{}", rc.to_string());
227                }
228            }
229        }
230        .into()
231    } else {
232        quote! {
233            #[no_mangle]
234            pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) #raw_output {
235                #constness #unsafety fn inner #generics(#inputs) #output {
236                    #block
237                }
238
239                let r = || inner(#(#raw_args,)*);
240                match r().and_then(|x| extism_pdk::Memory::new(&x)) {
241                    core::result::Result::Ok(mem) => {
242                        mem.offset()
243                    },
244                    core::result::Result::Err(rc) => {
245                        panic!("{}", rc.to_string());
246                    }
247                }
248            }
249        }
250        .into()
251    }
252}
253
254/// `host_fn` is used to import a host function from an `extern` block
255#[proc_macro_attribute]
256pub fn host_fn(
257    attr: proc_macro::TokenStream,
258    item: proc_macro::TokenStream,
259) -> proc_macro::TokenStream {
260    let namespace = if let Ok(ns) = syn::parse::<syn::LitStr>(attr) {
261        ns.value()
262    } else {
263        "extism:host/user".to_string()
264    };
265
266    let item = parse_macro_input!(item as ItemForeignMod);
267    if item.abi.name.is_none() || item.abi.name.unwrap().value() != "ExtismHost" {
268        panic!("Expected `extern \"ExtismHost\"` block");
269    }
270    let functions = item.items;
271
272    let mut gen = quote!();
273
274    for function in functions {
275        if let syn::ForeignItem::Fn(function) = function {
276            let name = &function.sig.ident;
277            let original_inputs = function.sig.inputs.clone();
278            let output = &function.sig.output;
279
280            let vis = &function.vis;
281            let generics = &function.sig.generics;
282            let mut into_inputs = vec![];
283            let mut converted_inputs = vec![];
284
285            let (output_is_ptr, converted_output) = match output {
286                syn::ReturnType::Default => (false, quote!(())),
287                syn::ReturnType::Type(_, _) => (true, quote!(u64)),
288            };
289
290            for input in &original_inputs {
291                match input {
292                    syn::FnArg::Typed(t) => {
293                        let mut input = t.clone();
294                        input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
295                        converted_inputs.push(syn::FnArg::Typed(input));
296                        match &*t.pat {
297                            syn::Pat::Ident(i) => {
298                                into_inputs
299                                    .push(quote!(
300                                        extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
301                                    ));
302                            }
303                            _ => panic!("invalid host function argument"),
304                        }
305                    }
306                    _ => panic!("self arguments are not permitted in host functions"),
307                }
308            }
309
310            let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
311            let link_name = name.to_string();
312            let link_name = link_name.as_str();
313
314            let impl_block = quote! {
315                #[link(wasm_import_module = #namespace)]
316                extern "C" {
317                    #[link_name = #link_name]
318                    fn #impl_name(#(#converted_inputs),*) -> #converted_output;
319                }
320            };
321
322            let output = match output {
323                syn::ReturnType::Default => quote!(()),
324                syn::ReturnType::Type(_, ty) => quote!(#ty),
325            };
326
327            if output_is_ptr {
328                gen = quote! {
329                    #gen
330
331                    #impl_block
332
333                    #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
334                        let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
335                        <#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
336                    }
337                };
338            } else {
339                gen = quote! {
340                    #gen
341
342                    #impl_block
343
344                    #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
345                        let res = #impl_name(#(#into_inputs),*);
346                        core::result::Result::Ok(res)
347                    }
348                };
349            }
350        }
351    }
352
353    gen.into()
354}