// enum_utils_from_str/lib.rs

1//! Code generation for a compile-time trie-based mapping from strings to arbitrary values.
2
3mod trie;
4
5use std::collections::BTreeMap;
6use std::io;
7
8use quote::{quote, ToTokens};
9use proc_macro2::{Literal, Ident, TokenStream, Span};
10
/// Generates a lookup function for all the key-value pairs contained in the tree.
///
/// # Examples
///
/// ```rust
/// # #![recursion_limit="128"]
/// # use quote::quote;
/// use enum_utils_from_str::StrMapFunc;
///
/// # fn main() {
/// // Compiling this trie into a lookup function...
/// let mut code = vec![];
/// StrMapFunc::new("custom_lookup", "bool")
///     .entries(vec![
///         ("yes", true),
///         ("yep", true),
///         ("no", false),
///     ])
///     .compile(&mut code);
///
/// // results in the following generated code.
///
/// # let generated = quote! {
/// fn custom_lookup(s: &[u8]) -> Option<bool> {
///     match s.len() {
///         2 => if s[0] == b'n' && s[1] == b'o' {
///             return Some(false);
///         },
///         3 => if s[0] == b'y' && s[1] == b'e' {
///             if s[2] == b'p' {
///                 return Some(true);
///             } else if s[2] == b's' {
///                 return Some(true);
///             }
///         },
///
///         _ => {}
///     }
///
///     None
/// }
/// # };
///
/// # assert_eq!(String::from_utf8(code).unwrap(), format!("{}", generated));
/// # }
/// ```
#[derive(Clone)]
pub struct StrMapFunc {
    // One trie per key length; each leaf holds the tokens of the mapped value.
    atoms: Forest<TokenStream>,
    // Name of the generated lookup function.
    func_name: Ident,
    // Tokens of the return type `T` in the generated `Option<T>`.
    ret_ty: TokenStream,
    // Whether byte comparisons in the generated code are case-sensitive.
    case: Case,
}
64
/// Controls how the generated function compares input bytes with the stored keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Case {
    /// Bytes must match exactly (`==`).
    Sensitive,
    /// Bytes are compared with `eq_ignore_ascii_case` (ASCII case-folding only).
    Insensitive,
}
70
71impl StrMapFunc {
72    pub fn new(func_name: &str, ret_ty: &str) -> Self {
73        StrMapFunc {
74            atoms: Default::default(),
75            func_name: Ident::new(func_name, Span::call_site()),
76            ret_ty: ret_ty.parse().unwrap(),
77            case: Case::Sensitive,
78        }
79    }
80
81    pub fn case(&mut self, case: Case) -> &mut Self {
82        self.case = case;
83        self
84    }
85
86    pub fn entry(&mut self, k: &str, v: impl ToTokens) -> &mut Self {
87        self.atoms.insert(k.as_bytes(), v.into_token_stream());
88        self
89    }
90
91    pub fn entries<'a, V: 'a>(&mut self, entries: impl IntoIterator<Item = (&'a str, V)>) -> &mut Self
92        where V: ToTokens,
93    {
94        for (s, v) in entries.into_iter() {
95            self.entry(s, v);
96        }
97
98        self
99    }
100
101    pub fn compile(&self, mut w: impl io::Write) -> io::Result<()> {
102        let tokens = self.into_token_stream();
103        w.write_all(format!("{}", tokens).as_bytes())
104    }
105}
106
107impl ToTokens for StrMapFunc {
108    fn to_tokens(&self, tokens: &mut TokenStream) {
109        let StrMapFunc { func_name, ret_ty, atoms, case } = self;
110
111        let match_arms = atoms.0.iter()
112            .map(|(&len, trie)| {
113                let branch = Forest::branch_tokens(trie, *case == Case::Insensitive);
114                let len = Literal::usize_unsuffixed(len);
115
116                quote!(#len => #branch)
117            });
118
119        let body = quote! {
120            match s.len() {
121                #( #match_arms, )*
122                _ => {}
123            }
124
125            None
126        };
127
128        tokens.extend(quote! {
129            fn #func_name(s: &[u8]) -> Option<#ret_ty> {
130                #body
131            }
132        });
133    }
134}
135
/// A set of tries where each trie only stores strings of a single length.
///
/// Keyed by string length; the per-length split lets the generated code
/// dispatch on `s.len()` before comparing any bytes.
#[derive(Debug, Clone)]
pub struct Forest<T>(BTreeMap<usize, trie::Node<T>>);
139
140impl<T> Default for Forest<T> {
141    fn default() -> Self {
142        Forest(Default::default())
143    }
144}
145
146impl<T> Forest<T> {
147    pub fn get(&mut self, bytes: &[u8]) -> Option<&T> {
148        self.0.get(&bytes.len())
149            .and_then(|n| n.get(bytes))
150    }
151
152    pub fn insert(&mut self, bytes: &[u8], value: T) -> Option<T> {
153        let node = self.0.entry(bytes.len()).or_default();
154        node.insert(bytes, value)
155    }
156}
157
158fn byte_literal(b: u8) -> TokenStream {
159    if b < 128 {
160        let c: String = char::from(b).escape_default().collect();
161        format!("b'{}'", c).parse().unwrap()
162    } else {
163        Literal::u8_unsuffixed(b).into_token_stream()
164    }
165}
166
impl<T> Forest<T>
    where T: ToTokens
{
    /// Compiles one trie into a nested `if`/`else if` chain that compares the
    /// bytes of `s` against the trie's edges and `return`s the stored value on
    /// a full match.
    ///
    /// Walks the trie depth-first; `tok` is a stack of token streams, one per
    /// currently-open brace level (index 0 is the output root).
    fn branch_tokens(node: &trie::Node<T>, ignore_case: bool) -> TokenStream {
        use trie::TraversalOrder::*;

        // Stack of partially-built bodies; pushed on Pre, popped/wrapped on Post.
        let mut tok = vec![TokenStream::new()];
        // Byte offset into `s` covered by the nodes above the current one.
        let mut depth = 0;
        // False once a sibling branch has been emitted at this level, so the
        // next sibling becomes an `else if` instead of a fresh `if`.
        let mut is_first_child = true;
        let mut dfs = node.dfs();
        while let Some((order, node)) = dfs.next() {
            // Nodes with no bytes (e.g. the root) emit no comparison.
            if node.bytes.is_empty() {
                continue;
            }

            match order {
                Pre => {
                    // Chain onto the previous sibling's `if`.
                    if !is_first_child {
                        tok.last_mut().unwrap().extend(quote!(else));
                        is_first_child = true;
                    }

                    // Compare each byte of this edge at its absolute index in `s`.
                    let i = (depth..depth+node.bytes.len()).map(Literal::usize_unsuffixed);
                    let b = node.bytes.iter().cloned().map(byte_literal);

                    if !ignore_case {
                        tok.last_mut().unwrap().extend(quote!(if #( s[#i] == #b )&&*));
                    } else {
                        tok.last_mut().unwrap().extend(quote!(if #( s[#i].eq_ignore_ascii_case(&#b) )&&*));
                    }

                    // Open the body of this `if`; closed again on Post.
                    tok.push(TokenStream::new());
                    depth += node.bytes.len();

                    if let Some(v) = node.value {
                        // A value here means the whole key matched (the forest
                        // stores fixed-length keys), so return immediately.
                        // TODO: debug_assert_eq!(dfs.next().0, Post);

                        tok.last_mut().unwrap().extend(quote!(return Some(#v);));
                    }
                }

                Post => {
                    // Close this node's body: wrap it in braces and append it
                    // to the parent level, then undo the Pre bookkeeping.
                    let body = tok.pop().unwrap();
                    tok.last_mut().unwrap().extend(quote!({ #body }));
                    depth -= node.bytes.len();
                    is_first_child = false;
                }
            }
        }

        // Only the root stream may remain; anything else means unbalanced
        // Pre/Post events from the traversal.
        let ret = tok.pop().unwrap();
        assert!(tok.is_empty());
        ret
    }
}
221}