1extern crate proc_macro;
6
7use proc_macro::TokenStream;
8
9fn get_byte_from_lit(lit: &syn::Lit) -> u8 {
10 if let syn::Lit::Byte(ref byte) = *lit {
11 byte.value()
12 } else {
13 panic!("Found a pattern that wasn't a byte")
14 }
15}
16
17fn get_byte_from_expr_lit(expr: &syn::Expr) -> u8 {
18 match *expr {
19 syn::Expr::Lit(syn::ExprLit { ref lit, .. }) => get_byte_from_lit(lit),
20 _ => unreachable!(),
21 }
22}
23
24fn parse_pat_to_table<'a>(
26 pat: &'a syn::Pat,
27 case_id: u8,
28 wildcard: &mut Option<&'a syn::Ident>,
29 table: &mut [u8; 256],
30) {
31 match pat {
32 syn::Pat::Lit(syn::PatLit { ref lit, .. }) => {
33 let value = get_byte_from_lit(lit);
34 if table[value as usize] == 0 {
35 table[value as usize] = case_id;
36 }
37 }
38 syn::Pat::Range(syn::PatRange {
39 ref start, ref end, ..
40 }) => {
41 let lo = get_byte_from_expr_lit(start.as_ref().unwrap());
42 let hi = get_byte_from_expr_lit(end.as_ref().unwrap());
43 for value in lo..hi {
44 if table[value as usize] == 0 {
45 table[value as usize] = case_id;
46 }
47 }
48 if table[hi as usize] == 0 {
49 table[hi as usize] = case_id;
50 }
51 }
52 syn::Pat::Wild(_) => {
53 for byte in table.iter_mut() {
54 if *byte == 0 {
55 *byte = case_id;
56 }
57 }
58 }
59 syn::Pat::Ident(syn::PatIdent { ref ident, .. }) => {
60 assert_eq!(*wildcard, None);
61 *wildcard = Some(ident);
62 for byte in table.iter_mut() {
63 if *byte == 0 {
64 *byte = case_id;
65 }
66 }
67 }
68 syn::Pat::Or(syn::PatOr { ref cases, .. }) => {
69 for case in cases {
70 parse_pat_to_table(case, case_id, wildcard, table);
71 }
72 }
73 _ => {
74 panic!("Unexpected pattern: {:?}. Buggy code ?", pat);
75 }
76 }
77}
78
79#[proc_macro]
93pub fn match_byte(input: TokenStream) -> TokenStream {
94 use syn::spanned::Spanned;
95 struct MatchByte {
96 expr: syn::Expr,
97 arms: Vec<syn::Arm>,
98 }
99
100 impl syn::parse::Parse for MatchByte {
101 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
102 Ok(MatchByte {
103 expr: {
104 let expr = input.parse()?;
105 input.parse::<syn::Token![,]>()?;
106 expr
107 },
108 arms: {
109 let mut arms = Vec::new();
110 while !input.is_empty() {
111 let arm = input.call(syn::Arm::parse)?;
112 assert!(arm.guard.is_none(), "match_byte doesn't support guards");
113 assert!(
114 arm.attrs.is_empty(),
115 "match_byte doesn't support attributes"
116 );
117 arms.push(arm);
118 }
119 arms
120 },
121 })
122 }
123 }
124 let MatchByte { expr, arms } = syn::parse_macro_input!(input);
125
126 let mut cases = Vec::new();
127 let mut table = [0u8; 256];
128 let mut match_body = Vec::new();
129 let mut wildcard = None;
130 for (i, ref arm) in arms.iter().enumerate() {
131 let case_id = i + 1;
132 let index = case_id as isize;
133 let name = syn::Ident::new(&format!("Case{case_id}"), arm.span());
134 let pat = &arm.pat;
135 parse_pat_to_table(pat, case_id as u8, &mut wildcard, &mut table);
136
137 cases.push(quote::quote!(#name = #index));
138 let body = &arm.body;
139 match_body.push(quote::quote!(Case::#name => { #body }))
140 }
141
142 let en = quote::quote!(enum Case {
143 #(#cases),*
144 });
145
146 let mut table_content = Vec::new();
147 for entry in table.iter() {
148 let name: syn::Path = syn::parse_str(&format!("Case::Case{entry}")).unwrap();
149 table_content.push(name);
150 }
151 let table = quote::quote!(static __CASES: [Case; 256] = [#(#table_content),*];);
152
153 if let Some(binding) = wildcard {
154 quote::quote!({ #en #table let #binding = #expr; match __CASES[#binding as usize] { #(#match_body),* }})
155 } else {
156 quote::quote!({ #en #table match __CASES[#expr as usize] { #(#match_body),* }})
157 }.into()
158}