1use quote::{quote, ToTokens};
2use proc_macro2::{Span};
3use syn::parse::{Result, Error};
4use std::default::Default;
5use std::ffi::CStr;
6
7
/// Options gathered from the `#[cstr(...)]` attributes attached to one enum variant.
///
/// The `Default` value has every option unset.
#[derive(Default)]
struct VariantMeta {
    /// Literal built from `name = "..."`, stored with a trailing nul byte appended;
    /// `None` when no `name` argument was supplied.
    pub name: Option<syn::LitByteStr>,
}
12
13impl VariantMeta {
14 pub fn from_attrs(attrs: &[syn::Attribute]) -> Result<Self> {
16 let mut opts = VariantMeta::default();
17
18 for attr in attrs {
19 if attr.path.is_ident("cstr") {
20 opts.parse_meta(attr.parse_meta()?)?
21 }
22 }
23 Ok(opts)
24 }
25
26 pub fn parse_meta(&mut self, meta: syn::Meta) -> Result<()> {
28 match meta {
29 syn::Meta::List(nvs) => {
30 for nv in nvs.nested {
31 match nv {
32 syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) => self.parse_nv(nv)?,
33 _ => return Err(Error::new_spanned(nv, "expected named argument (KEY = VALUE)"))
34 }
35 }
36 }
37 _ => return Err(Error::new_spanned(meta, "missing arguments: expected `cstr(...)`"))
38 }
39 Ok(())
40 }
41
42 fn parse_nv(&mut self, nv: syn::MetaNameValue) -> Result<()> {
44 if let Some(ident) = nv.path.get_ident() {
45 if ident == "name" {
46 Self::check_not_set(&self.name, ident)?;
47 match nv.lit {
48 syn::Lit::Str(s) => {
49 let mut name = s.value();
50 name.push('\0');
51 if CStr::from_bytes_with_nul(name.as_bytes()).is_err() {
52 return Err(Error::new_spanned(s, "string cannot contain nul bytes"));
53 }
54 self.name = Some(syn::LitByteStr::new(name.as_bytes(), s.span()));
55 return Ok(());
56 }
57 lit => { return Err(Error::new_spanned(lit, "expected string literal")); }
58 }
59 }
60 }
64 Err(Error::new_spanned(nv.path, "invalid named argument"))
65 }
66
67 fn check_not_set<T>(field: &Option<T>, tokens: impl ToTokens) -> Result<()> {
69 if field.is_some() {
70 Err(Error::new_spanned(tokens, "duplicate named argument"))
71 } else {
72 Ok(())
73 }
74 }
75}
76
77fn ident_to_byte_str_lit(ident: &syn::Ident) -> syn::LitByteStr {
79 let cstring = {
80 let mut s = ident.to_string();
81 s.push('\0');
82 s
83 };
84 syn::LitByteStr::new(cstring.as_bytes(), Span::call_site())
85}
86
87fn check_enum_attrs(input: &syn::DeriveInput) -> Result<()> {
89 for attr in &input.attrs {
90 if attr.path.is_ident("cstr") {
91 return Err(Error::new_spanned(attr, "attribute must be placed on variants"));
92 }
93 }
94 Ok(())
95}
96
97fn get_name_mapping<'a>(input: &'a syn::DeriveInput, unit_variants_only: bool) -> Result<(Vec<&'a syn::Ident>, Vec<syn::LitByteStr>)> {
99 check_enum_attrs(input)?;
100
101 let variants = match &input.data {
102 syn::Data::Enum(enm) => &enm.variants,
103 _ => return Err(Error::new(Span::call_site(), "target must be an enum")),
104 };
105
106 let mut idents = Vec::with_capacity(variants.len());
107 let mut bytestrs = Vec::with_capacity(variants.len());
108
109 #[allow(unused_variables)]
110 for variant in variants {
111 if unit_variants_only && variant.fields != syn::Fields::Unit {
112 return Err(Error::new_spanned(variant, "variant cannot have fields"));
113 }
114 let ident = &variant.ident;
116 let opts = VariantMeta::from_attrs(&variant.attrs)?;
117
118 bytestrs.push(opts.name.unwrap_or_else(|| ident_to_byte_str_lit(&ident)));
120 idents.push(ident);
121 }
122 Ok((idents, bytestrs))
123}
124
125
126#[proc_macro_derive(AsCStr, attributes(cstr))]
128pub fn derive_ascstr_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
129 let input = syn::parse_macro_input!(input as syn::DeriveInput);
130
131
132 let (var_idents, vals) = match get_name_mapping(&input, false) {
133 Ok(m) => m,
134 Err(e) => { return e.to_compile_error().into(); }
135 };
136
137 let ident = &input.ident;
138
139 let ts = quote! {
140 impl cstr_enum::AsCStr for #ident {
141 fn as_cstr(&self) -> &'static std::ffi::CStr {
142 match self {
143 #( Self::#var_idents{..} => unsafe {std::ffi::CStr::from_bytes_with_nul_unchecked(#vals) }, )*
144 }
145 }
146 }
147 };
148
149 ts.into()
150}
151
152
153#[proc_macro_derive(FromCStr, attributes(cstr))]
155pub fn derive_fromcstr_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
156 let input = syn::parse_macro_input!(input as syn::DeriveInput);
157
158 let (var_idents, mut vals) = match get_name_mapping(&input, true) {
159 Ok(m) => m,
160 Err(e) => { return e.to_compile_error().into(); }
161 };
162
163 for v in vals.iter_mut() {
164 let bytes = v.value();
165 *v = syn::LitByteStr::new(&bytes[..bytes.len() - 1], v.span())
166 }
167
168
169 let ident = &input.ident;
170 let error_msg = syn::LitStr::new(&format!("unexpected string while parsing for {} variant", ident), Span::call_site());
171
172 let ts = quote! {
173 impl cstr_enum::FromCStr for #ident {
174 type Err = &'static str;
175 fn from_cstr(s: &std::ffi::CStr) -> Result<Self, Self::Err> {
176 match s.to_bytes() {
177 #( #vals => Ok(Self::#var_idents), )*
178 _ => Err(#error_msg)
179 }
180 }
181 }
182 };
183
184 ts.into()
185}
186