extism_convert_macros/
lib.rs

1use std::iter;
2
3use manyhow::{ensure, error_message, manyhow, Result};
4use proc_macro_crate::{crate_name, FoundCrate};
5use quote::{format_ident, quote, ToTokens};
6use syn::{parse_quote, Attribute, DeriveInput, Path};
7
8/// Tries to resolve the path to `extism_convert` dynamically, falling back to feature flags when unsuccessful.
9fn convert_path() -> Path {
10    match (
11        crate_name("extism"),
12        crate_name("extism-convert"),
13        crate_name("extism-pdk"),
14    ) {
15        (Ok(FoundCrate::Name(name)), ..) => {
16            let ident = format_ident!("{name}");
17            parse_quote!(::#ident::convert)
18        }
19        (_, Ok(FoundCrate::Name(name)), ..) | (.., Ok(FoundCrate::Name(name))) => {
20            let ident = format_ident!("{name}");
21            parse_quote!(::#ident)
22        }
23        (Ok(FoundCrate::Itself), ..) => parse_quote!(::extism::convert),
24        (_, Ok(FoundCrate::Itself), ..) => parse_quote!(::extism_convert),
25        (.., Ok(FoundCrate::Itself)) => parse_quote!(::extism_pdk),
26        _ if cfg!(feature = "extism-path") => parse_quote!(::extism::convert),
27        _ if cfg!(feature = "extism-pdk-path") => parse_quote!(::extism_pdk),
28        _ => parse_quote!(::extism_convert),
29    }
30}
31
32fn extract_encoding(attrs: &[Attribute]) -> Result<Path> {
33    let encodings: Vec<_> = attrs
34        .iter()
35        .filter(|attr| attr.path().is_ident("encoding"))
36        .collect();
37    ensure!(!encodings.is_empty(), "encoding needs to be specified"; try = "`#[encoding(Json)]`");
38    ensure!(encodings.len() < 2, encodings[1], "only one encoding can be specified"; try = "remove `{}`", encodings[1].to_token_stream());
39
40    Ok(encodings[0].parse_args().map_err(
41        |e| error_message!(e.span(), "{e}"; note= "expects a path"; try = "`#[encoding(Json)]`"),
42    )?)
43}
44
45#[manyhow]
46#[proc_macro_derive(ToBytes, attributes(encoding))]
47pub fn to_bytes(
48    DeriveInput {
49        attrs,
50        ident,
51        generics,
52        ..
53    }: DeriveInput,
54) -> Result {
55    let encoding = extract_encoding(&attrs)?;
56    let convert = convert_path();
57
58    let (_, type_generics, _) = generics.split_for_impl();
59
60    let mut generics = generics.clone();
61    generics.make_where_clause().predicates.push(
62        parse_quote!(for<'__to_bytes_b> #encoding<&'__to_bytes_b Self>: #convert::ToBytes<'__to_bytes_b>)
63    );
64    generics.params = iter::once(parse_quote!('__to_bytes_a))
65        .chain(generics.params)
66        .collect();
67    let (impl_generics, _, where_clause) = generics.split_for_impl();
68
69    Ok(quote! {
70        impl #impl_generics #convert::ToBytes<'__to_bytes_a> for #ident #type_generics #where_clause
71        {
72            type Bytes = ::std::vec::Vec<u8>;
73
74            fn to_bytes(&self) -> Result<Self::Bytes, #convert::Error> {
75                #convert::ToBytes::to_bytes(&#encoding(self)).map(|__bytes| __bytes.as_ref().to_vec())
76            }
77        }
78
79    })
80}
81
82#[manyhow]
83#[proc_macro_derive(FromBytes, attributes(encoding))]
84pub fn from_bytes(
85    DeriveInput {
86        attrs,
87        ident,
88        mut generics,
89        ..
90    }: DeriveInput,
91) -> Result {
92    let encoding = extract_encoding(&attrs)?;
93    let convert = convert_path();
94    generics
95        .make_where_clause()
96        .predicates
97        .push(parse_quote!(#encoding<Self>: #convert::FromBytesOwned));
98    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
99    Ok(quote! {
100        impl #impl_generics #convert::FromBytesOwned for #ident #type_generics #where_clause
101        {
102            fn from_bytes_owned(__data: &[u8]) -> Result<Self, #convert::Error> {
103                <#encoding<Self> as #convert::FromBytesOwned>::from_bytes_owned(__data).map(|__encoding| __encoding.0)
104            }
105        }
106
107    })
108}