Skip to main content

ws_framer_macros/
lib.rs

1use itertools::Itertools;
2use proc_macro::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{parse::Parser, punctuated::Punctuated, Expr, Lit, Token};
5
6#[proc_macro]
7pub fn base64_impl(item: TokenStream) -> TokenStream {
8    let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
9    let args = parser.parse(item).unwrap();
10    if args.len() != 4 {
11        panic!("This macro requires 3 arguments (structName, \"CHARSET\", padding, std)")
12    }
13
14    let struct_name = if let Expr::Path(struct_name_expr) = args.get(0).unwrap() {
15        let segments = struct_name_expr.path.segments.iter().collect::<Vec<_>>();
16
17        if segments.len() != 1 {
18            panic!("First argument should be simple struct name (one segment)");
19        }
20
21        segments[0].ident.clone()
22    } else {
23        panic!("First argument not a Expr::Lit!");
24    };
25
26    let charset = if let Expr::Lit(charset_expr) = args.get(1).unwrap() {
27        if let Lit::Str(charset_str) = &charset_expr.lit {
28            charset_str.value()
29        } else {
30            panic!("Second argument not a string!");
31        }
32    } else {
33        panic!("Second argument not a Expr::Lit!");
34    };
35
36    let pad = if let Expr::Lit(pad_expr) = args.get(2).unwrap() {
37        if let Lit::Bool(pad_val) = &pad_expr.lit {
38            pad_val.value()
39        } else {
40            panic!("Third argument not a bool!");
41        }
42    } else {
43        panic!("Third argument not a Expr::Lit!");
44    };
45
46    let use_std = if let Expr::Lit(std_expr) = args.get(3).unwrap() {
47        if let Lit::Bool(std_val) = &std_expr.lit {
48            std_val.value()
49        } else {
50            panic!("Fourth argument not a bool!");
51        }
52    } else {
53        panic!("Fourth argument not a Expr::Lit!");
54    };
55
56    let encode_map = charset.chars().collect::<Vec<_>>();
57    let copied_encode_map = encode_map.clone().into_iter().unique().collect::<Vec<_>>();
58    if encode_map.len() != copied_encode_map.len() {
59        panic!("Characters cannot contain duplicates!");
60    }
61
62    let mut decode_map = vec![255; 255];
63    for i in 0..encode_map.len() {
64        let char_val = encode_map[i] as u8;
65        decode_map[char_val as usize] = i as u8;
66    }
67
68    let encode_map = encode_map
69        .iter()
70        .map(|c| {
71            quote! {
72                #c,
73            }
74        })
75        .collect::<Vec<_>>();
76
77    let decode_map = decode_map
78        .iter()
79        .map(|c| {
80            quote! {
81                #c,
82            }
83        })
84        .collect::<Vec<_>>();
85
86    let pad_token = match pad {
87        true => quote! {
88            output[out_ptr..].fill(b'=');
89        },
90        false => quote! {},
91    };
92
93    let encode_len_tokens = match pad {
94        true => quote! {
95            (n + 2) / 3 * 4
96        },
97        false => quote! {
98            n / 3 * 4 + (n % 3 * 4 + 2) / 3
99        },
100    };
101
102    let decode_len_tokens = match pad {
103        true => quote! {
104            (n / 4) * 3
105        },
106        false => quote! {
107            (n * 3) / 4
108        },
109    };
110
111    let use_std_tokens = match use_std {
112        true => quote! {
113            pub fn encode(input: &[u8]) -> String {
114                let mut output = vec![0; Self::encode_len(input.len())];
115                Self::encode_slice(input, &mut output);
116
117                String::from_utf8(output).expect("Base64 utf8 error")
118            }
119
120            pub fn decode(input: &str) -> Vec<u8> {
121                let mut output = vec![0; Self::decode_len(input.len())];
122                let n = Self::decode_slice(input.as_bytes(), &mut output);
123
124                output[..n].to_vec()
125            }
126        },
127        false => quote! {},
128    };
129
130    let encode_map_len = encode_map.len();
131    quote! {
132        pub struct #struct_name;
133        impl #struct_name {
134            const ENCODE_MAP: [char; #encode_map_len] = [
135                #(#encode_map)*
136            ];
137
138            const DECODE_MAP: [u8; 255] = [
139                #(#decode_map)*
140            ];
141
142
143            pub fn encode_slice(input: &[u8], output: &mut [u8]) {
144                if Self::encode_len(input.len()) > output.len() {
145                    panic!("Output buffer too small!!! TODO: Make this as result, not as a panic LMAO");
146                }
147
148                // stack
149                let mut bit_size = 0 as usize;
150                let mut bit_stack = 0 as u64;
151
152                let mut out_ptr = 0;
153                for byte in input {
154                    bit_stack <<= 8;
155                    bit_stack |= *byte as u64;
156                    bit_size += 8;
157
158                    if bit_size == 24 {
159                        output[out_ptr + 0] = Self::ENCODE_MAP[((bit_stack & 0b111111000000000000000000) >> 18) as usize] as u8;
160                        output[out_ptr + 1] = Self::ENCODE_MAP[((bit_stack & 0b111111000000000000) >> 12) as usize] as u8;
161                        output[out_ptr + 2] = Self::ENCODE_MAP[((bit_stack & 0b111111000000) >> 6) as usize] as u8;
162                        output[out_ptr + 3] = Self::ENCODE_MAP[(bit_stack & 0b111111) as usize] as u8;
163
164                        out_ptr += 4;
165                        bit_size = 0;
166                    }
167                }
168
169                // align bits to 6's
170                let to_align = 6 - (bit_size % 6);
171                bit_stack <<= to_align;
172                bit_size += to_align;
173
174                let mut pad_len = 4;
175                while bit_size > 0 {
176                    let shift = bit_size - 6;
177                    output[out_ptr] = Self::ENCODE_MAP[((bit_stack & (0b111111 << shift)) >> shift) as usize] as u8;
178                    bit_size -= 6;
179                    pad_len -= 1;
180                    out_ptr += 1;
181                }
182
183                #pad_token
184            }
185
186            pub fn decode_slice(input: &[u8], output: &mut [u8]) -> usize {
187                let mut out_ptr = 0;
188
189                // stack
190                let mut bit_stack = 0 as u64;
191                let mut bit_size = 0usize;
192
193                for &c in input {
194                    if c == b'=' {
195                        break;
196                    }
197
198                    let val = Self::DECODE_MAP[c as usize];
199                    if val == 64 {
200                        panic!("Wrong base64 character! ({:?})", c);
201                    }
202
203                    bit_stack <<= 6;
204                    bit_stack |= val as u64;
205                    bit_size += 6;
206
207                    if bit_size >= 8 {
208                        let shift = bit_size - 8;
209                        let byte = ((bit_stack & (0b11111111 << shift)) >> shift) as u8;
210                        bit_size -= 8;
211
212                        output[out_ptr] = byte;
213                        out_ptr += 1;
214                    }
215                }
216
217                out_ptr
218            }
219
220            #use_std_tokens
221
222            pub const fn encode_len(n: usize) -> usize {
223                #encode_len_tokens
224            }
225
226            pub const fn decode_len(n: usize) -> usize {
227                #decode_len_tokens
228            }
229        }
230    }
231    .to_token_stream()
232    .into()
233}