cstr_enum_derive/
lib.rs

1use quote::{quote, ToTokens};
2use proc_macro2::{Span};
3use syn::parse::{Result, Error};
4use std::default::Default;
5use std::ffi::CStr;
6
7
8#[derive(Default)]
9struct VariantMeta {
10  pub name: Option<syn::LitByteStr>,
11}
12
13impl VariantMeta {
14  /// Build the enum variant meta info from attributes
15  pub fn from_attrs(attrs: &[syn::Attribute]) -> Result<Self> {
16    let mut opts = VariantMeta::default();
17
18    for attr in attrs {
19      if attr.path.is_ident("cstr") {
20        opts.parse_meta(attr.parse_meta()?)?
21      }
22    }
23    Ok(opts)
24  }
25
26  /// Parse a single #[cstr(...)] item on a variant
27  pub fn parse_meta(&mut self, meta: syn::Meta) -> Result<()> {
28    match meta {
29      syn::Meta::List(nvs) => {
30        for nv in nvs.nested {
31          match nv {
32            syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) => self.parse_nv(nv)?,
33            _ => return Err(Error::new_spanned(nv, "expected named argument (KEY = VALUE)"))
34          }
35        }
36      }
37      _ => return Err(Error::new_spanned(meta, "missing arguments: expected `cstr(...)`"))
38    }
39    Ok(())
40  }
41
42  /// Parse a single item in the list of name-value pairs inside the #[cstr(...)]
43  fn parse_nv(&mut self, nv: syn::MetaNameValue) -> Result<()> {
44    if let Some(ident) = nv.path.get_ident() {
45      if ident == "name" {
46        Self::check_not_set(&self.name, ident)?;
47        match nv.lit {
48          syn::Lit::Str(s) => {
49            let mut name = s.value();
50            name.push('\0');
51            if CStr::from_bytes_with_nul(name.as_bytes()).is_err() {
52              return Err(Error::new_spanned(s, "string cannot contain nul bytes"));
53            }
54            self.name = Some(syn::LitByteStr::new(name.as_bytes(), s.span()));
55            return Ok(());
56          }
57          lit => { return Err(Error::new_spanned(lit, "expected string literal")); }
58        }
59      }
60      // future attributes can be added here.  Annoyingly, a match statement doesn't work
61      // since `ident` is of a different type
62      // ...
63    }
64    Err(Error::new_spanned(nv.path, "invalid named argument"))
65  }
66
67  /// Check the field hasn't been set before by another attribute item
68  fn check_not_set<T>(field: &Option<T>, tokens: impl ToTokens) -> Result<()> {
69    if field.is_some() {
70      Err(Error::new_spanned(tokens, "duplicate named argument"))
71    } else {
72      Ok(())
73    }
74  }
75}
76
77/// Convert an ident to a nul-terminated byte-string literal.
78fn ident_to_byte_str_lit(ident: &syn::Ident) -> syn::LitByteStr {
79  let cstring = {
80    let mut s = ident.to_string();
81    s.push('\0');
82    s
83  };
84  syn::LitByteStr::new(cstring.as_bytes(), Span::call_site())
85}
86
87/// Check that #[cstr(...)] is not applied to the enum itself
88fn check_enum_attrs(input: &syn::DeriveInput) -> Result<()> {
89  for attr in &input.attrs {
90    if attr.path.is_ident("cstr") {
91      return Err(Error::new_spanned(attr, "attribute must be placed on variants"));
92    }
93  }
94  Ok(())
95}
96
97/// Retrieve the name mapping between enum variants and their CStr representations
98fn get_name_mapping<'a>(input: &'a syn::DeriveInput, unit_variants_only: bool) -> Result<(Vec<&'a syn::Ident>, Vec<syn::LitByteStr>)> {
99  check_enum_attrs(input)?;
100
101  let variants = match &input.data {
102    syn::Data::Enum(enm) => &enm.variants,
103    _ => return Err(Error::new(Span::call_site(), "target must be an enum")),
104  };
105
106  let mut idents = Vec::with_capacity(variants.len());
107  let mut bytestrs = Vec::with_capacity(variants.len());
108
109  #[allow(unused_variables)]
110  for variant in variants {
111    if unit_variants_only && variant.fields != syn::Fields::Unit {
112      return Err(Error::new_spanned(variant, "variant cannot have fields"));
113    }
114    // parse name from attributes
115    let ident = &variant.ident;
116    let opts = VariantMeta::from_attrs(&variant.attrs)?;
117
118    // Default to the ident of the variant
119    bytestrs.push(opts.name.unwrap_or_else(|| ident_to_byte_str_lit(&ident)));
120    idents.push(ident);
121  }
122  Ok((idents, bytestrs))
123}
124
125
126/// Derive macro for the [`AsCStr`] trait.  May only be applied to enums.
127#[proc_macro_derive(AsCStr, attributes(cstr))]
128pub fn derive_ascstr_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
129  let input = syn::parse_macro_input!(input as syn::DeriveInput);
130
131
132  let (var_idents, vals) = match get_name_mapping(&input, false) {
133    Ok(m) => m,
134    Err(e) => { return e.to_compile_error().into(); }
135  };
136
137  let ident = &input.ident;
138
139  let ts = quote! {
140       impl cstr_enum::AsCStr for #ident {
141            fn as_cstr(&self) -> &'static std::ffi::CStr {
142                match self {
143                    #( Self::#var_idents{..} => unsafe {std::ffi::CStr::from_bytes_with_nul_unchecked(#vals) }, )*
144                }
145            }
146       }
147    };
148
149  ts.into()
150}
151
152
153/// Derive macro for the [`FromCStr`] trait.  May only be applied to enums whose variants have no fields.
154#[proc_macro_derive(FromCStr, attributes(cstr))]
155pub fn derive_fromcstr_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
156  let input = syn::parse_macro_input!(input as syn::DeriveInput);
157
158  let (var_idents, mut vals) = match get_name_mapping(&input, true) {
159    Ok(m) => m,
160    Err(e) => { return e.to_compile_error().into(); }
161  };
162
163  for v in vals.iter_mut() {
164    let bytes = v.value();
165    *v = syn::LitByteStr::new(&bytes[..bytes.len() - 1], v.span())
166  }
167
168
169  let ident = &input.ident;
170  let error_msg = syn::LitStr::new(&format!("unexpected string while parsing for {} variant", ident), Span::call_site());
171
172  let ts = quote! {
173       impl cstr_enum::FromCStr for #ident {
174            type Err = &'static str;
175            fn from_cstr(s: &std::ffi::CStr) -> Result<Self, Self::Err> {
176                match s.to_bytes() {
177                    #( #vals => Ok(Self::#var_idents), )*
178                    _ => Err(#error_msg)
179                }
180            }
181       }
182    };
183
184  ts.into()
185}
186