1use proc_macro2::{Ident, Span};
2use quote::quote;
3use syn::{parse_macro_input, FnArg, GenericArgument, ItemFn, ItemForeignMod, PathArguments};
4
5#[proc_macro_attribute]
24pub fn plugin_fn(
25 _attr: proc_macro::TokenStream,
26 item: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28 let mut function = parse_macro_input!(item as ItemFn);
29
30 if !matches!(function.vis, syn::Visibility::Public(..)) {
31 panic!("extism_pdk::plugin_fn expects a public function");
32 }
33
34 let name = &function.sig.ident;
35 let constness = &function.sig.constness;
36 let unsafety = &function.sig.unsafety;
37 let generics = &function.sig.generics;
38 let inputs = &mut function.sig.inputs;
39 let output = &mut function.sig.output;
40 let block = &function.block;
41
42 let no_args = inputs.is_empty();
43
44 if name == "main" {
45 panic!(
46 "extism_pdk::plugin_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
47 )
48 }
49
50 match output {
51 syn::ReturnType::Default => panic!(
52 "extism_pdk::plugin_fn expects a return value, `()` may be used if no output is needed"
53 ),
54 syn::ReturnType::Type(_, t) => {
55 if let syn::Type::Path(p) = t.as_ref() {
56 if let Some(t) = p.path.segments.last() {
57 if t.ident != "FnResult" {
58 panic!("extism_pdk::plugin_fn expects a function that returns extism_pdk::FnResult");
59 }
60 } else {
61 panic!("extism_pdk::plugin_fn expects a function that returns extism_pdk::FnResult");
62 }
63 }
64 }
65 }
66
67 if no_args {
68 quote! {
69 #[no_mangle]
70 pub #constness #unsafety extern "C" fn #name() -> i32 {
71 #constness #unsafety fn inner #generics() #output {
72 #block
73 }
74
75 let output = match inner() {
76 core::result::Result::Ok(x) => x,
77 core::result::Result::Err(rc) => {
78 let err = format!("{:?}", rc.0);
79 let mut mem = extism_pdk::Memory::from_bytes(&err).unwrap();
80 unsafe {
81 extism_pdk::extism::error_set(mem.offset());
82 }
83 return rc.1;
84 }
85 };
86 extism_pdk::unwrap!(extism_pdk::output(&output));
87 0
88 }
89 }
90 .into()
91 } else {
92 quote! {
93 #[no_mangle]
94 pub #constness #unsafety extern "C" fn #name() -> i32 {
95 #constness #unsafety fn inner #generics(#inputs) #output {
96 #block
97 }
98
99 let input = extism_pdk::unwrap!(extism_pdk::input());
100 let output = match inner(input) {
101 core::result::Result::Ok(x) => x,
102 core::result::Result::Err(rc) => {
103 let err = format!("{:?}", rc.0);
104 let mut mem = extism_pdk::Memory::from_bytes(&err).unwrap();
105 unsafe {
106 extism_pdk::extism::error_set(mem.offset());
107 }
108 return rc.1;
109 }
110 };
111 extism_pdk::unwrap!(extism_pdk::output(&output));
112 0
113 }
114 }
115 .into()
116 }
117}
118
119#[proc_macro_attribute]
137pub fn shared_fn(
138 _attr: proc_macro::TokenStream,
139 item: proc_macro::TokenStream,
140) -> proc_macro::TokenStream {
141 let mut function = parse_macro_input!(item as ItemFn);
142
143 if !matches!(function.vis, syn::Visibility::Public(..)) {
144 panic!("extism_pdk::shared_fn expects a public function");
145 }
146
147 let name = &function.sig.ident;
148 let constness = &function.sig.constness;
149 let unsafety = &function.sig.unsafety;
150 let generics = &function.sig.generics;
151 let inputs = &mut function.sig.inputs;
152 let output = &mut function.sig.output;
153 let block = &function.block;
154
155 let (raw_inputs, raw_args): (Vec<_>, Vec<_>) = inputs
156 .iter()
157 .enumerate()
158 .map(|(i, x)| {
159 let t = match x {
160 FnArg::Receiver(_) => {
161 panic!("Receiver argument (self) cannot be used in extism_pdk::shared_fn")
162 }
163 FnArg::Typed(t) => &t.ty,
164 };
165 let arg = Ident::new(&format!("arg{i}"), Span::call_site());
166 (
167 quote! { #arg: extism_pdk::MemoryPointer<#t> },
168 quote! { #arg.get()? },
169 )
170 })
171 .unzip();
172
173 if name == "main" {
174 panic!(
175 "export_pdk::shared_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
176 )
177 }
178
179 let (no_result, raw_output) = match output {
180 syn::ReturnType::Default => (true, quote! {}),
181 syn::ReturnType::Type(_, t) => {
182 let mut is_unit = false;
183 if let syn::Type::Path(p) = t.as_ref() {
184 if let Some(t) = p.path.segments.last() {
185 if t.ident != "SharedFnResult" {
186 panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
187 }
188 match &t.arguments {
189 PathArguments::AngleBracketed(args) => {
190 if args.args.len() == 1 {
191 match &args.args[0] {
192 GenericArgument::Type(syn::Type::Tuple(t)) => {
193 if t.elems.is_empty() {
194 is_unit = true;
195 }
196 }
197 _ => (),
198 }
199 }
200 }
201 _ => (),
202 }
203 } else {
204 panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
205 }
206 };
207 if is_unit {
208 (true, quote! {})
209 } else {
210 (false, quote! {-> u64 })
211 }
212 }
213 };
214
215 if no_result {
216 quote! {
217 #[no_mangle]
218 pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) {
219 #constness #unsafety fn inner #generics(#inputs) -> extism_pdk::SharedFnResult<()> {
220 #block
221 }
222
223
224 let r = || inner(#(#raw_args,)*);
225 if let Err(rc) = r() {
226 panic!("{}", rc.to_string());
227 }
228 }
229 }
230 .into()
231 } else {
232 quote! {
233 #[no_mangle]
234 pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) #raw_output {
235 #constness #unsafety fn inner #generics(#inputs) #output {
236 #block
237 }
238
239 let r = || inner(#(#raw_args,)*);
240 match r().and_then(|x| extism_pdk::Memory::new(&x)) {
241 core::result::Result::Ok(mem) => {
242 mem.offset()
243 },
244 core::result::Result::Err(rc) => {
245 panic!("{}", rc.to_string());
246 }
247 }
248 }
249 }
250 .into()
251 }
252}
253
254#[proc_macro_attribute]
256pub fn host_fn(
257 attr: proc_macro::TokenStream,
258 item: proc_macro::TokenStream,
259) -> proc_macro::TokenStream {
260 let namespace = if let Ok(ns) = syn::parse::<syn::LitStr>(attr) {
261 ns.value()
262 } else {
263 "extism:host/user".to_string()
264 };
265
266 let item = parse_macro_input!(item as ItemForeignMod);
267 if item.abi.name.is_none() || item.abi.name.unwrap().value() != "ExtismHost" {
268 panic!("Expected `extern \"ExtismHost\"` block");
269 }
270 let functions = item.items;
271
272 let mut gen = quote!();
273
274 for function in functions {
275 if let syn::ForeignItem::Fn(function) = function {
276 let name = &function.sig.ident;
277 let original_inputs = function.sig.inputs.clone();
278 let output = &function.sig.output;
279
280 let vis = &function.vis;
281 let generics = &function.sig.generics;
282 let mut into_inputs = vec![];
283 let mut converted_inputs = vec![];
284
285 let (output_is_ptr, converted_output) = match output {
286 syn::ReturnType::Default => (false, quote!(())),
287 syn::ReturnType::Type(_, _) => (true, quote!(u64)),
288 };
289
290 for input in &original_inputs {
291 match input {
292 syn::FnArg::Typed(t) => {
293 let mut input = t.clone();
294 input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
295 converted_inputs.push(syn::FnArg::Typed(input));
296 match &*t.pat {
297 syn::Pat::Ident(i) => {
298 into_inputs
299 .push(quote!(
300 extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
301 ));
302 }
303 _ => panic!("invalid host function argument"),
304 }
305 }
306 _ => panic!("self arguments are not permitted in host functions"),
307 }
308 }
309
310 let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
311 let link_name = name.to_string();
312 let link_name = link_name.as_str();
313
314 let impl_block = quote! {
315 #[link(wasm_import_module = #namespace)]
316 extern "C" {
317 #[link_name = #link_name]
318 fn #impl_name(#(#converted_inputs),*) -> #converted_output;
319 }
320 };
321
322 let output = match output {
323 syn::ReturnType::Default => quote!(()),
324 syn::ReturnType::Type(_, ty) => quote!(#ty),
325 };
326
327 if output_is_ptr {
328 gen = quote! {
329 #gen
330
331 #impl_block
332
333 #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
334 let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
335 <#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
336 }
337 };
338 } else {
339 gen = quote! {
340 #gen
341
342 #impl_block
343
344 #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
345 let res = #impl_name(#(#into_inputs),*);
346 core::result::Result::Ok(res)
347 }
348 };
349 }
350 }
351 }
352
353 gen.into()
354}