lighter_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::{quote, quote_spanned};
5use std::mem;
6use syn::{
7    parse_macro_input, parse_quote, parse_quote_spanned, parse_str, spanned::Spanned, Arm, Expr,
8    ExprMatch, Ident, Lit, LitByte, Pat,
9};
10
11// return the body of the arm of `m` with the given byte as its pattern, if it exists
12fn find_arm(m: &mut ExprMatch, byte: u8) -> Option<&mut Expr> {
13    for arm in m.arms.iter_mut() {
14        // these ugly nested if statements are just to get at
15        // the literal byte (e.g. 1 in Option::Some(b'\x01'))
16        if let Pat::TupleStruct(expr) = &arm.pat {
17            if expr.path == parse_quote!(::core::option::Option::Some) && expr.pat.elems.len() == 1
18            {
19                if let Some(Pat::Lit(expr)) = expr.pat.elems.first() {
20                    if let Expr::Lit(expr) = expr.expr.as_ref() {
21                        if let Lit::Byte(b) = &expr.lit {
22                            if b.value() == byte {
23                                return Some(&mut arm.body);
24                            }
25                        }
26                    }
27                } else {
28                    panic!("weird arm {:?}", expr.pat.elems);
29                }
30            }
31        }
32    }
33
34    None
35}
36
37fn insert_arm(expr: &mut Expr, case: &[u8], arm: Arm, match_prefix: bool) {
38    match case {
39        // we are at a leaf: for an n-character string, we're n matches deep,
40        // and so we have no more chars to match. iff the iterator is empty &
41        // thus the string we're matching is over, or if we are only matching
42        // a prefix, the original arm's body runs
43        [] => {
44            let arm = Arm {
45                pat: if match_prefix {
46                    // when we are only matching a prefix, we don't care what
47                    // comes after the prefix, or whether the string ends after
48                    // the characters we've matched so far
49                    parse_quote!(_)
50                } else {
51                    parse_quote!(::core::option::Option::None)
52                },
53                ..arm
54            };
55
56            match expr {
57                // if expr is already a match statement, we can add to it as is
58                Expr::Match(m) => m.arms.push(arm),
59
60                // if our input is some other sort of statement, make it a wild
61                // match arm that will come first and always execute, such that
62                // arms added later (including the wild arm added to all match
63                // statements by insert_wild) unreachable. this may seem silly,
64                // but the goal is to trigger an "unreachable pattern" warning
65                // when the user does something like the following:
66                // ```
67                // lighter! { match s {
68                //     Prefix("") => println!("all strings start with the empty string"),
69                //     "hi" => unreachable!(),
70                //     _ => unreachable!(),
71                // } }
72                // ```
73                expr => {
74                    let e = mem::replace(expr, parse_quote!({}));
75                    *expr = parse_quote! {
76                        match __lighter_internal_iter.next() {
77                            _ => #e,
78                            #arm
79                        }
80                    };
81                }
82            }
83        }
84
85        // we are at a leaf for a prefix match: we don't need another match
86        // statement a level after this to check iterator.next() = None, as
87        // it's OK for the iterator to have more items after this one
88        [prefix] if match_prefix => {
89            // the format! is a workaround for a bug in
90            // LitByte::value where values created with
91            // LitByte::new are not parsed correctly
92            let mut b = parse_str::<LitByte>(&format!("b'\\x{:02x}'", prefix)).unwrap();
93            b.set_span(arm.pat.span());
94
95            let arm = Arm {
96                pat: parse_quote!(::core::option::Option::Some(#b)),
97                ..arm
98            };
99
100            match expr {
101                Expr::Match(m) => m.arms.push(arm),
102                expr => {
103                    let e = mem::replace(expr, parse_quote!({}));
104                    *expr = parse_quote! {
105                        match __lighter_internal_iter.next() {
106                            _ => #e,
107                            #arm
108                        }
109                    }
110                }
111            }
112        }
113
114        // there is at least one byte left to match, let's find or create
115        // another level of match statement for each next byte recursively
116        [prefix, suffix @ ..] => match expr {
117            Expr::Match(m) => {
118                let m_arm = match find_arm(m, *prefix) {
119                    // an arm already exists with our prefix byte;
120                    // insert our string's suffix relative to that
121                    Some(m_arm) => m_arm,
122
123                    // an arm does not yet exist for this prefix
124                    None => {
125                        // the format! is a workaround for a bug in
126                        // LitByte::value where values created with
127                        // LitByte::new are not parsed correctly
128                        let mut b = parse_str::<LitByte>(&format!("b'\\x{:02x}'", prefix)).unwrap();
129                        b.set_span(arm.pat.span());
130
131                        // TODO: parse_quote_spanned! ?
132                        m.arms.push(parse_quote! {
133                            ::core::option::Option::Some(#b) => match __lighter_internal_iter.next() {},
134                        });
135
136                        m.arms.last_mut().unwrap().body.as_mut()
137                    }
138                };
139
140                insert_arm(m_arm, suffix, arm, match_prefix);
141            }
142            expr => {
143                // the format! is a workaround for a bug in
144                // LitByte::value where values created with
145                // LitByte::new are not parsed correctly
146                // (TODO: report this bug)
147                let mut b = parse_str::<LitByte>(&format!("b'\\x{:02x}'", prefix)).unwrap();
148                b.set_span(arm.pat.span());
149
150                // TODO: is there a simpler placeholder expression than {}?
151                let e = mem::replace(expr, parse_quote!({}));
152                *expr = parse_quote! {
153                    match __lighter_internal_iter.next() {
154                        _ => #e,
155                        ::core::option::Option::Some(#b) => match __lighter_internal_iter.next() {},
156                    }
157                };
158            }
159        },
160    }
161}
162
163// recursively append wild/fallback cases to every match expression that doesn't already have one
164fn insert_wild(expr: &mut Expr, wild: &[Arm]) {
165    if let Expr::Match(m) = expr {
166        let mut has_wild = false;
167        for arm in m.arms.iter_mut() {
168            insert_wild(arm.body.as_mut(), wild);
169            if let Pat::Wild(_) = arm.pat {
170                has_wild = true;
171            }
172        }
173
174        if !has_wild {
175            m.arms.extend_from_slice(wild);
176        }
177    }
178}
179
180// TODO: error handling
181// TODO: assert no attrs etc.
182fn parse_arm(match_out: &mut Expr, wild: &mut Vec<Arm>, arm: Arm, prefix: bool) {
183    match arm.pat {
184        Pat::Lit(ref expr) => match expr.expr.as_ref() {
185            Expr::Lit(expr) => match &expr.lit {
186                Lit::Str(expr) => insert_arm(match_out, expr.value().as_bytes(), arm, prefix),
187                // TODO: handle if guards
188                _ => todo!("non-str lit"),
189            },
190            _ => todo!("non-lit expr"),
191        },
192        Pat::TupleStruct(expr)
193            if expr.path == parse_quote!(Prefix) && expr.pat.elems.len() == 1 =>
194        {
195            let arm = Arm {
196                pat: expr.pat.elems.into_iter().next().unwrap(),
197                ..arm
198            };
199
200            parse_arm(match_out, wild, arm, true)
201        } // TODO
202        Pat::Or(expr) => {
203            //for pat in &expr.cases {
204            for pat in expr.cases {
205                parse_arm(
206                    match_out,
207                    wild,
208                    Arm {
209                        attrs: arm.attrs.clone(),
210                        pat,
211                        guard: arm.guard.clone(),
212                        fat_arrow_token: arm.fat_arrow_token,
213                        body: arm.body.clone(),
214                        comma: arm.comma,
215                    },
216                    prefix,
217                )
218            }
219        }
220        Pat::Wild(_) => wild.push(arm),
221        x => todo!("non-lit pat {:?}", x),
222        //_ => todo!("non-lit pat"),
223    }
224}
225
226#[proc_macro]
227pub fn lighter(input: TokenStream) -> TokenStream {
228    let ExprMatch {
229        attrs,
230        match_token,
231        expr,
232        brace_token,
233        arms,
234    } = parse_macro_input!(input as ExprMatch);
235    if !attrs.is_empty() {
236        panic!("I don't know what to do with attributes on a match statement");
237    }
238
239    let mut wild = Vec::new();
240    let mut match_out = Expr::Match(ExprMatch {
241        attrs,
242        match_token,
243        expr: parse_quote_spanned!(expr.span()=> __lighter_internal_iter.next()),
244        brace_token,
245        arms: Vec::new(), // TODO
246    });
247
248    for arm in arms {
249        parse_arm(&mut match_out, &mut wild, arm, false);
250    }
251
252    insert_wild(&mut match_out, &wild);
253
254    let krate = match crate_name("lighter") {
255        Ok(FoundCrate::Name(name)) => Ident::new(&name, Span::call_site()),
256        _ => parse_quote!(lighter),
257    };
258
259    let make_iter = quote_spanned! {expr.span()=>
260        (&mut &mut &mut ::#krate::__internal::Wrap(Some(#expr))).bytes()
261    };
262
263    TokenStream::from(quote! {
264        {
265            use ::#krate::__internal::*;
266            let mut __lighter_internal_iter = #make_iter;
267            #match_out
268        }
269    })
270}
271
272/*
273// TODO
274#[cfg(test)]
275mod tests {
276    #[test]
277    fn it_works() {
278        let result = 2 + 2;
279        assert_eq!(result, 4);
280    }
281}
282*/