1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#![allow(unused_variables)]
#![allow(dead_code)]

use proc_macro::TokenStream;
#[cfg(any(feature = "bincode", feature = "bitcode", feature = "serde"))]
use quote::quote;
use syn::{
	parse::{Parse, ParseStream},
	punctuated::Punctuated,
	Ident, Token,
};

/// Options accepted by the [`apply`] attribute macro.
///
/// Each field corresponds to one identifier that may appear in the
/// comma-separated attribute argument list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Args {
	/// Set when the `encode` option is present.
	encode: bool,
	/// Set when the `decode` option is present.
	decode: bool,
}

impl Parse for Args {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		let options = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;

		let mut encode = false;
		let mut decode = false;

		for option in options {
			match option.to_string().as_str() {
				"encode" => encode = true,
				"decode" => decode = true,
				_ => return Err(syn::Error::new(option.span(), "unknown option")),
			}
		}

		if !encode && !decode {
			return Err(syn::Error::new(
				input.span(),
				"at least one of `encode` or `decode` must be enabled",
			));
		}

		Ok(Self { encode, decode })
	}
}

/// A utility macro for automatically deriving the correct traits
/// depending on the enabled features.
///
/// The attribute takes a comma-separated list containing `encode`, `decode`,
/// or both (e.g. `#[apply(encode, decode)]`); at least one is required.
/// The selected derives are prepended to the annotated item, which is
/// otherwise emitted unchanged.
///
/// Attributes emitted per enabled crate feature:
///
/// - `serde`: `Serialize` for `encode`, `Deserialize` for `decode`, plus
///   `#[serde(crate = "...")]`.
/// - `bincode`: `Encode` for `encode`, `Decode` for `decode`, plus
///   `#[bincode(crate = "...")]`.
/// - `bitcode`: `bitcode::Encode` for `encode`, `bitcode::Decode` for
///   `decode`.
/// - `aide`: `JsonSchema` plus `#[schemars(crate = "...")]`, regardless of
///   the selected options.
/// - `validator`: `validator::Validate` for `decode` only.
///
/// If the attribute arguments fail to parse, the parse error is emitted as
/// a compile error in place of the item.
#[proc_macro_attribute]
pub fn apply(attr: TokenStream, input: TokenStream) -> TokenStream {
	let args = syn::parse_macro_input!(attr as Args);

	// Derive attributes are collected first; the original item goes last.
	let mut tokens = TokenStream::default();

	#[cfg(feature = "serde")]
	{
		if args.encode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(axum_codec::__private::serde::Serialize)]
			}));
		}

		if args.decode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(axum_codec::__private::serde::Deserialize)]
			}));
		}

		tokens.extend(TokenStream::from(quote! {
			#[serde(crate = "axum_codec::__private::serde")]
		}));
	}

	#[cfg(feature = "bincode")]
	{
		if args.encode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(axum_codec::__private::bincode::Encode)]
			}));
		}

		if args.decode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(axum_codec::__private::bincode::Decode)]
			}));
		}

		tokens.extend(TokenStream::from(quote! {
			#[bincode(crate = "axum_codec::__private::bincode")]
		}));
	}

	// TODO: Implement #[bitcode(crate = "...")]
	// For now, use the real crate name so the error is nicer.
	#[cfg(feature = "bitcode")]
	{
		if args.encode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(bitcode::Encode)]
			}));
		}

		if args.decode {
			tokens.extend(TokenStream::from(quote! {
				#[derive(bitcode::Decode)]
			}));
		}
	}

	// The file-level `use quote::quote;` is only compiled in for the
	// `bincode`, `bitcode` and `serde` features, so the macro is invoked by
	// its full path here to keep an `aide`-only build compiling.
	#[cfg(feature = "aide")]
	tokens.extend(TokenStream::from(quote::quote! {
		#[derive(axum_codec::__private::schemars::JsonSchema)]
		#[schemars(crate = "axum_codec::__private::schemars")]
	}));

	// TODO: Implement #[validate(crate = "...")]
	// For now, use the real crate name so the error is nicer.
	//
	// Fully-qualified `quote::quote!` for the same reason as the `aide`
	// branch: the short name may not be imported with this feature alone.
	#[cfg(feature = "validator")]
	if args.decode {
		tokens.extend(TokenStream::from(quote::quote! {
			#[derive(validator::Validate)]
		}));
	}

	tokens.extend(input);
	tokens
}