use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, Data, DeriveInput, Expr, Fields, Ident, ItemFn, LitStr, Token,
};
/// Arguments accepted by the `#[serializable(...)]` type attribute.
///
/// Each flag selects an inherent method to generate. An empty argument list
/// (plain `#[serializable]`) enables both flags.
struct SerializableArgs {
    /// Generate `seal(&self, key)` and the per-field `seal_<field>` helpers.
    seal: bool,
    /// Generate `open(bytes, key)`.
    open: bool,
}
impl Parse for SerializableArgs {
    /// Parses a comma-separated list of the uppercase flags `SEAL` and `OPEN`.
    ///
    /// One trailing comma is accepted. Parsing stops after the first flag that
    /// is not followed by a comma; an unrecognized identifier is an error
    /// pointing at that identifier.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        if input.is_empty() {
            return Ok(SerializableArgs { seal: true, open: true });
        }
        let mut args = SerializableArgs {
            seal: false,
            open: false,
        };
        loop {
            let flag: Ident = input.parse()?;
            if flag == "SEAL" {
                args.seal = true;
            } else if flag == "OPEN" {
                args.open = true;
            } else {
                let other = flag.to_string();
                return Err(syn::Error::new(
                    flag.span(),
                    format!(
                        "#[serializable]: unknown argument `{other}`. \
                         Valid arguments: SEAL, OPEN"
                    ),
                ));
            }
            // Continue only when a comma is followed by another flag.
            let had_comma = input.parse::<Option<Token![,]>>()?.is_some();
            if !had_comma || input.is_empty() {
                break;
            }
        }
        Ok(args)
    }
}
pub fn expand_serializable(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as SerializableArgs);
let mut input = parse_macro_input!(item as DeriveInput);
let name = input.ident.clone();
let (impl_generics_ts, ty_generics_ts, where_clause_ts) = {
let (i, t, w) = input.generics.split_for_impl();
(quote!(#i), quote!(#t), quote!(#w))
};
let mut per_field_methods: Vec<TokenStream2> = vec![];
if let Data::Struct(ref mut ds) = input.data {
let fields_opt = match &mut ds.fields {
Fields::Named(f) => Some(&mut f.named),
Fields::Unnamed(f) => Some(&mut f.unnamed),
Fields::Unit => None,
};
if let Some(fields) = fields_opt {
for field in fields.iter_mut() {
let field_name = match &field.ident {
Some(id) => id.clone(),
None => continue, };
let mut found_key: Option<LitStr> = None;
field.attrs.retain(|a| {
if a.path().is_ident("serializable") {
if let Ok(lit) = a.parse_args_with(|inp: ParseStream| {
let kw: Ident = inp.parse()?;
if kw != "key" {
return Err(syn::Error::new(
kw.span(),
"#[serializable]: field attribute syntax is \
`#[serializable(key = \"your-key\")]`",
));
}
inp.parse::<Token![=]>()?;
inp.parse::<LitStr>()
}) {
found_key = Some(lit);
}
false } else {
true
}
});
if let Some(key_lit) = found_key {
if args.seal {
let seal_fn = format_ident!("seal_{}", field_name);
per_field_methods.push(quote! {
pub fn #seal_fn(
&self,
) -> ::std::result::Result<
::std::vec::Vec<u8>,
::toolkit_zero::serialization::SerializationError,
> {
::toolkit_zero::serialization::seal(
&self.#field_name,
::std::option::Option::Some(#key_lit.to_string()),
)
}
});
}
}
}
}
}
let seal_method = if args.seal {
quote! {
pub fn seal(
&self,
key: ::std::option::Option<::std::string::String>,
) -> ::std::result::Result<
::std::vec::Vec<u8>,
::toolkit_zero::serialization::SerializationError,
> {
::toolkit_zero::serialization::seal(self, key)
}
}
} else {
quote! {}
};
let open_method = if args.open {
quote! {
pub fn open(
bytes: &[u8],
key: ::std::option::Option<::std::string::String>,
) -> ::std::result::Result<
Self,
::toolkit_zero::serialization::SerializationError,
> {
::toolkit_zero::serialization::open::<Self, ::std::string::String>(bytes, key)
}
}
} else {
quote! {}
};
quote! {
#[derive(
::toolkit_zero::serialization::Encode,
::toolkit_zero::serialization::Decode,
)]
#[bincode(crate = "::toolkit_zero::serialization::bincode")]
#input
impl #impl_generics_ts #name #ty_generics_ts #where_clause_ts {
#seal_method
#open_method
#(#per_field_methods)*
}
}
.into()
}
/// Arguments of the `#[serialize(...)]` attribute.
///
/// Grammar: `<source expr> [, path = <expr>] [, key = <expr>] [,]`. The keyword
/// arguments may appear in any order; a repeated keyword overrides the earlier one.
struct SerializeArgs {
    /// Expression whose value gets sealed.
    source: Expr,
    /// Destination file expression; `None` selects variable mode.
    path: Option<Expr>,
    /// Key expression, passed to `seal` wrapped in `Some(..)`.
    key: Option<Expr>,
}
impl Parse for SerializeArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut args = SerializeArgs {
            source: input.parse()?,
            path: None,
            key: None,
        };
        while input.parse::<Option<Token![,]>>()?.is_some() {
            // A trailing comma ends the argument list.
            if input.is_empty() {
                break;
            }
            let keyword: Ident = input.parse()?;
            let slot = if keyword == "path" {
                &mut args.path
            } else if keyword == "key" {
                &mut args.key
            } else {
                let other = keyword.to_string();
                return Err(syn::Error::new(
                    keyword.span(),
                    format!(
                        "#[serialize]: unknown keyword `{other}`. \
                         Valid keywords: path = <expr>, key = <expr>"
                    ),
                ));
            };
            input.parse::<Token![=]>()?;
            *slot = Some(input.parse()?);
        }
        Ok(args)
    }
}
/// Expands `#[serialize(source, path = ..., key = ...)]` on a placeholder `fn` item.
///
/// The annotated function is not re-emitted; it is replaced by a single statement,
/// so the attribute only compiles where a statement is allowed.
///
/// * **File mode** (`path = <expr>` given): seals `source` and writes the bytes to
///   `path` with `std::fs::write`. Both the seal error and the I/O error are
///   propagated with `?`.
/// * **Variable mode** (no `path`): expands to
///   `let <fn name>: <fn return type> = seal(&(source), key)?;`. The function must
///   declare a return type, otherwise a compile error is emitted.
///
/// When `key` is omitted, `None::<String>` is passed.
pub fn expand_serialize(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as SerializeArgs);
    let func = parse_macro_input!(item as ItemFn);
    let key_arg = match &args.key {
        Some(k) => quote! { ::std::option::Option::Some(#k) },
        None => quote! { ::std::option::Option::<::std::string::String>::None },
    };
    let source = &args.source;
    // Parenthesize so `&` borrows the whole user expression. A bare `&#source`
    // would turn `a + b` into `(&a) + b` and `x as T` into `(&x) as T`.
    let source_ref = quote! { &(#source) };
    match &args.path {
        Some(path_expr) => quote! {
            ::std::fs::write(
                #path_expr,
                ::toolkit_zero::serialization::seal(#source_ref, #key_arg)?,
            )?;
        },
        None => {
            let var_name = &func.sig.ident;
            let ret_ty = match &func.sig.output {
                syn::ReturnType::Type(_, ty) => ty.as_ref(),
                syn::ReturnType::Default => {
                    return syn::Error::new(
                        func.sig.ident.span(),
                        "#[serialize]: a return type is required in variable mode. \
                         Example: `fn blob() -> Vec<u8> {}`",
                    )
                    .to_compile_error()
                    .into();
                }
            };
            quote! {
                let #var_name: #ret_ty =
                    ::toolkit_zero::serialization::seal(#source_ref, #key_arg)?;
            }
        }
    }
    .into()
}
/// Where `#[deserialize]` reads its sealed bytes from.
enum DeserializeSource {
    /// An expression producing the bytes.
    Blob(Expr),
    /// A file path expression; the expansion reads the file with `std::fs::read`.
    Path(Expr),
}
/// Arguments of the `#[deserialize(...)]` attribute.
///
/// Grammar: `(<blob expr> | path = <expr>) [, key = <expr>] [,]`.
struct DeserializeArgs {
    source: DeserializeSource,
    key: Option<Expr>,
}
impl Parse for DeserializeArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Look ahead on a fork: only the exact form `path =` selects path mode,
        // so a plain variable named `path` is still parsed as a blob expression.
        let is_path_form = {
            let ahead = input.fork();
            matches!(ahead.parse::<Ident>(), Ok(id) if id == "path") && ahead.peek(Token![=])
        };
        let source = if is_path_form {
            input.parse::<Ident>()?;
            input.parse::<Token![=]>()?;
            DeserializeSource::Path(input.parse()?)
        } else {
            DeserializeSource::Blob(input.parse()?)
        };
        let mut key = None;
        while input.parse::<Option<Token![,]>>()?.is_some() {
            // A trailing comma ends the argument list.
            if input.is_empty() {
                break;
            }
            let keyword: Ident = input.parse()?;
            if keyword != "key" {
                let other = keyword.to_string();
                return Err(syn::Error::new(
                    keyword.span(),
                    format!(
                        "#[deserialize]: unknown keyword `{other}`. \
                         Valid keywords: key = <expr>"
                    ),
                ));
            }
            input.parse::<Token![=]>()?;
            key = Some(input.parse()?);
        }
        Ok(DeserializeArgs { source, key })
    }
}
/// Expands `#[deserialize(...)]` on a placeholder `fn` item into a `let` binding.
///
/// `fn name() -> T {}` annotated with `#[deserialize(src, key = k)]` becomes
/// `let name: T = open::<T, _>(&(src), Some(k))?;`. With `path = p` in place of a
/// blob expression, the bytes come from `&std::fs::read(p)?` instead. The function
/// body is discarded. The function must declare a return type, otherwise a compile
/// error is emitted. When `key` is omitted, `None::<String>` is passed.
pub fn expand_deserialize(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as DeserializeArgs);
    let func = parse_macro_input!(item as ItemFn);
    let var_name = &func.sig.ident;
    let ret_ty = match &func.sig.output {
        syn::ReturnType::Type(_, ty) => ty.as_ref(),
        syn::ReturnType::Default => {
            return syn::Error::new(
                func.sig.ident.span(),
                "#[deserialize]: a return type is required. \
                 Example: `fn config() -> MyStruct {}`",
            )
            .to_compile_error()
            .into();
        }
    };
    let key_arg = match &args.key {
        Some(k) => quote! { ::std::option::Option::Some(#k) },
        None => quote! { ::std::option::Option::<::std::string::String>::None },
    };
    let bytes_expr = match &args.source {
        // Parenthesized so `&` borrows the whole user expression. A bare `&#expr`
        // would turn `x as T` into `(&x) as T` or `a..b` into `(&a)..b`.
        DeserializeSource::Blob(expr) => quote! { &(#expr) },
        DeserializeSource::Path(path_expr) => quote! { &::std::fs::read(#path_expr)? },
    };
    quote! {
        let #var_name: #ret_ty =
            ::toolkit_zero::serialization::open::<#ret_ty, _>(#bytes_expr, #key_arg)?;
    }
    .into()
}