bmart_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::str::FromStr;
4use syn::{parse_macro_input, Lit, Meta, MetaNameValue};
5
// Extracts the contents of a string literal from a `syn::Lit` expression.
//
// The argument is evaluated exactly once. Any literal kind other than `Lit::Str`
// (integers, bools, byte strings, ...) aborts macro expansion with a panic.
macro_rules! litstr {
    ($lit: expr) => {
        match $lit {
            Lit::Str(s) => s.value(),
            _ => panic!("invalid string value"),
        }
    };
}
15
16/// Sorting for structures
17///
18/// Automatically implements Eq, PartialEq, Ord and PartialOrd for single-field comparison,
19/// supports structures with no or a single lifetime.
20///
21/// The default sorting field is "id", can be overriden with sorting(id = "field") attribute:
22///
23/// # Panics
24///
25/// Will panic on invalid attributes and if the expression is not a struct
26///
27/// ```rust
28/// use bmart_derive::Sorting;
29///
30/// #[derive(Sorting)]
31/// #[sorting(id = "name")]
32/// struct MyStruct {
33///     name: String,
34///     value: u32
35/// }
36/// ```
37#[proc_macro_derive(Sorting, attributes(sorting))]
38pub fn sorting_derive(input: TokenStream) -> TokenStream {
39    let sitem = parse_macro_input!(input as syn::ItemStruct);
40    let sid = &sitem.ident;
41    let mut owned = true;
42    for param in sitem.generics.params {
43        if let syn::GenericParam::Lifetime(_) = param {
44            owned = false;
45            break;
46        }
47    }
48    let mut id = "id".to_owned();
49    for a in &sitem.attrs {
50        if a.path.is_ident("sorting") {
51            if let Ok(nameval) = a.parse_args::<MetaNameValue>() {
52                if nameval.path.is_ident("id") {
53                    id = litstr!(nameval.lit);
54                } else {
55                    panic!("invalid attribute")
56                }
57            } else {
58                panic!("invalid attribute")
59            }
60        }
61    }
62    let i_id = format_ident!("{}", id);
63    let tr = if owned {
64        quote! {
65            impl Eq for #sid {}
66            impl Ord for #sid {
67                fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
68                    self.#i_id.cmp(&other.#i_id)
69                }
70            }
71            impl PartialOrd for #sid {
72                fn partial_cmp(&self, other: &Self) -> Option<::std::cmp::Ordering> {
73                    Some(self.cmp(other))
74                }
75            }
76            impl PartialEq for #sid {
77                fn eq(&self, other: &Self) -> bool {
78                    self.#i_id == other.#i_id
79                }
80            }
81        }
82    } else {
83        quote! {
84            impl<'srt> Eq for #sid<'srt> {}
85            impl<'srt> Ord for #sid<'srt> {
86                fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
87                    self.#i_id.cmp(&other.#i_id)
88                }
89            }
90            impl<'srt> PartialOrd for #sid<'srt> {
91                fn partial_cmp(&self, other: &Self) -> Option<::std::cmp::Ordering> {
92                    Some(self.cmp(other))
93                }
94            }
95            impl<'srt> PartialEq for #sid<'srt> {
96                fn eq(&self, other: &Self) -> bool {
97                    self.#i_id == other.#i_id
98                }
99            }
100        }
101    };
102    TokenStream::from(tr)
103}
104
/// Identifier casing styles accepted by `#[enumstr(rename_all = "...")]`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Case {
    /// `"lowercase"`
    Lower,
    /// `"UPPERCASE"`
    Upper,
    /// `"snake_case"`
    Snake,
    /// `"SCREAMING_SNAKE_CASE"`
    ScrSnake,
    /// `"kebab-case"`
    Kebab,
    /// `"SCREAMING-KEBAB-CASE"`
    ScrKebab,
    /// `"CamelCase"`: the variant name is kept as-is
    Camel,
}

impl FromStr for Case {
    // A real error type rather than `Infallible`: unknown names are a genuine failure, and
    // panicking inside `from_str` contradicted the "cannot fail" promise of `Infallible`.
    type Err = String;

    /// Parses a case name spelled exactly as in the `rename_all` attribute.
    ///
    /// # Errors
    ///
    /// Returns `"unsupported case: <name>"` for any unrecognized name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "lowercase" => Ok(Case::Lower),
            "UPPERCASE" => Ok(Case::Upper),
            "snake_case" => Ok(Case::Snake),
            "SCREAMING_SNAKE_CASE" => Ok(Case::ScrSnake),
            "kebab-case" => Ok(Case::Kebab),
            "SCREAMING-KEBAB-CASE" => Ok(Case::ScrKebab),
            "CamelCase" => Ok(Case::Camel),
            other => Err(format!("unsupported case: {}", other)),
        }
    }
}
131
132fn format_case(s: &str, case: Case) -> String {
133    match case {
134        Case::Camel => s.to_owned(),
135        Case::Lower => s.to_lowercase(),
136        Case::Upper => s.to_uppercase(),
137        Case::Snake | Case::ScrSnake | Case::Kebab | Case::ScrKebab => {
138            let sep = if case == Case::Snake || case == Case::ScrSnake {
139                "_"
140            } else {
141                "-"
142            };
143            let mut result = String::new();
144            for c in s.chars() {
145                if c.is_uppercase() && !result.is_empty() {
146                    result += sep;
147                }
148                result.push(c);
149            }
150            if case == Case::Snake || case == Case::Kebab {
151                result.to_lowercase()
152            } else {
153                result.to_uppercase()
154            }
155        }
156    }
157}
158
159struct EnumVar {
160    id: String,
161    name: Option<String>,
162    aliases: Vec<String>,
163    skip: bool,
164}
165
166impl EnumVar {
167    fn new(i: &syn::Ident) -> Self {
168        Self {
169            id: i.to_string(),
170            name: None,
171            aliases: Vec::new(),
172            skip: false,
173        }
174    }
175}
176
177/// Implements Display and FromStr for enums with no data attached. The default behavior is to use
178/// snake_case. Can be overriden with enumstr(rename_all = "case")
179///
180/// The possible case values: "lowercase", "UPPERCASE", "snake_case", "SCREAMING_SNAKE_CASE",
181/// "kebab-case", "SCREAMING-KEBAB-CASE". "CamelCase" (as-is)
182///
183/// Individual fields can be overriden with enumstr(rename = "name"), altered with enumstr(alias =
184/// "alias")
185///
186/// Fields, marked with enumstr(skip), are skipted in FromStr implementation.
187///
188/// To avoid additional dependancies, parse() Err type is String.
189///
190/// # Panics
191///
192/// Will panic on invalid attributes and if the expression is not an enum
193///
194/// ```rust
195/// use bmart_derive::EnumStr;
196///
197/// #[derive(EnumStr)]
198/// #[enumstr(rename_all = "snake_case")]
199/// enum MyEnum {
200///     Field1,
201///     Field2,
202///     #[enumstr(skip)]
203///     SecretField,
204///     VeryLongField,
205///     #[enumstr(rename = "another")]
206///     #[enumstr(alias = "a")]
207///     #[enumstr(alias = "af")]
208///     AnotherField
209/// }
210/// ```
211#[proc_macro_derive(EnumStr, attributes(enumstr))]
212pub fn enumstr_derive(input: TokenStream) -> TokenStream {
213    let sitem = parse_macro_input!(input as syn::ItemEnum);
214    let mut vars: Vec<EnumVar> = Vec::new();
215    for var in &sitem.variants {
216        let mut evar = EnumVar::new(&var.ident);
217        for a in &var.attrs {
218            if a.path.is_ident("enumstr") {
219                if let Ok(nameval) = a.parse_args::<MetaNameValue>() {
220                    if nameval.path.is_ident("rename") {
221                        evar.name = Some(litstr!(nameval.lit));
222                    } else if nameval.path.is_ident("alias") {
223                        evar.aliases.push(litstr!(nameval.lit));
224                    } else {
225                        panic!("invalid attribute")
226                    }
227                } else if let Ok(name) = a.parse_args::<Meta>() {
228                    if name.path().is_ident("skip") {
229                        evar.skip = true;
230                    } else {
231                        panic!("invalid attribute")
232                    }
233                } else {
234                    panic!("invalid attribute")
235                }
236            }
237        }
238        vars.push(evar);
239    }
240    let sid = &sitem.ident;
241    let mut case = Case::Snake;
242    for a in &sitem.attrs {
243        if a.path.is_ident("enumstr") {
244            if let Ok(nameval) = a.parse_args::<MetaNameValue>() {
245                if nameval.path.is_ident("rename_all") {
246                    case = litstr!(nameval.lit).parse().unwrap();
247                } else {
248                    panic!("invalid attribute")
249                }
250            } else {
251                panic!("invalid attribute")
252            }
253        }
254    }
255    let mut st_to = "match self {".to_owned();
256    let mut st_from = "match s {".to_owned();
257    for var in vars {
258        let name = if let Some(name) = var.name {
259            name
260        } else {
261            format_case(&var.id, case)
262        };
263        st_to += &format!("{}::{} => \"{}\",", sid, var.id, name);
264        if !var.skip {
265            st_from += &format!("\"{}\"", name);
266            for alias in var.aliases {
267                st_from += &format!(" | \"{}\"", alias);
268            }
269            st_from += &format!(" => Ok({}::{}),", sid, var.id);
270        }
271    }
272    st_to += "}";
273    st_from += "_ => Err(\"value unsupported: \".to_owned() + s)}";
274    let m_to: syn::ExprMatch = syn::parse_str(&st_to).unwrap();
275    let m_from: syn::ExprMatch = syn::parse_str(&st_from).unwrap();
276    let tr = quote! {
277        impl ::std::fmt::Display for #sid {
278            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
279                write!(f, "{}", #m_to)
280            }
281        }
282        impl ::std::str::FromStr for #sid {
283            type Err = String;
284            fn from_str(s: &str) -> Result<Self, Self::Err> {
285                #m_from
286            }
287        }
288    };
289    TokenStream::from(tr)
290}