1use itertools::Itertools;
2use proc_macro::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{parse::Parser, punctuated::Punctuated, Expr, Lit, Token};
5
6#[proc_macro]
7pub fn base64_impl(item: TokenStream) -> TokenStream {
8 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
9 let args = parser.parse(item).unwrap();
10 if args.len() != 4 {
11 panic!("This macro requires 3 arguments (structName, \"CHARSET\", padding, std)")
12 }
13
14 let struct_name = if let Expr::Path(struct_name_expr) = args.get(0).unwrap() {
15 let segments = struct_name_expr.path.segments.iter().collect::<Vec<_>>();
16
17 if segments.len() != 1 {
18 panic!("First argument should be simple struct name (one segment)");
19 }
20
21 segments[0].ident.clone()
22 } else {
23 panic!("First argument not a Expr::Lit!");
24 };
25
26 let charset = if let Expr::Lit(charset_expr) = args.get(1).unwrap() {
27 if let Lit::Str(charset_str) = &charset_expr.lit {
28 charset_str.value()
29 } else {
30 panic!("Second argument not a string!");
31 }
32 } else {
33 panic!("Second argument not a Expr::Lit!");
34 };
35
36 let pad = if let Expr::Lit(pad_expr) = args.get(2).unwrap() {
37 if let Lit::Bool(pad_val) = &pad_expr.lit {
38 pad_val.value()
39 } else {
40 panic!("Third argument not a bool!");
41 }
42 } else {
43 panic!("Third argument not a Expr::Lit!");
44 };
45
46 let use_std = if let Expr::Lit(std_expr) = args.get(3).unwrap() {
47 if let Lit::Bool(std_val) = &std_expr.lit {
48 std_val.value()
49 } else {
50 panic!("Fourth argument not a bool!");
51 }
52 } else {
53 panic!("Fourth argument not a Expr::Lit!");
54 };
55
56 let encode_map = charset.chars().collect::<Vec<_>>();
57 let copied_encode_map = encode_map.clone().into_iter().unique().collect::<Vec<_>>();
58 if encode_map.len() != copied_encode_map.len() {
59 panic!("Characters cannot contain duplicates!");
60 }
61
62 let mut decode_map = vec![255; 255];
63 for i in 0..encode_map.len() {
64 let char_val = encode_map[i] as u8;
65 decode_map[char_val as usize] = i as u8;
66 }
67
68 let encode_map = encode_map
69 .iter()
70 .map(|c| {
71 quote! {
72 #c,
73 }
74 })
75 .collect::<Vec<_>>();
76
77 let decode_map = decode_map
78 .iter()
79 .map(|c| {
80 quote! {
81 #c,
82 }
83 })
84 .collect::<Vec<_>>();
85
86 let pad_token = match pad {
87 true => quote! {
88 output[out_ptr..].fill(b'=');
89 },
90 false => quote! {},
91 };
92
93 let encode_len_tokens = match pad {
94 true => quote! {
95 (n + 2) / 3 * 4
96 },
97 false => quote! {
98 n / 3 * 4 + (n % 3 * 4 + 2) / 3
99 },
100 };
101
102 let decode_len_tokens = match pad {
103 true => quote! {
104 (n / 4) * 3
105 },
106 false => quote! {
107 (n * 3) / 4
108 },
109 };
110
111 let use_std_tokens = match use_std {
112 true => quote! {
113 pub fn encode(input: &[u8]) -> String {
114 let mut output = vec![0; Self::encode_len(input.len())];
115 Self::encode_slice(input, &mut output);
116
117 String::from_utf8(output).expect("Base64 utf8 error")
118 }
119
120 pub fn decode(input: &str) -> Vec<u8> {
121 let mut output = vec![0; Self::decode_len(input.len())];
122 let n = Self::decode_slice(input.as_bytes(), &mut output);
123
124 output[..n].to_vec()
125 }
126 },
127 false => quote! {},
128 };
129
130 let encode_map_len = encode_map.len();
131 quote! {
132 pub struct #struct_name;
133 impl #struct_name {
134 const ENCODE_MAP: [char; #encode_map_len] = [
135 #(#encode_map)*
136 ];
137
138 const DECODE_MAP: [u8; 255] = [
139 #(#decode_map)*
140 ];
141
142
143 pub fn encode_slice(input: &[u8], output: &mut [u8]) {
144 if Self::encode_len(input.len()) > output.len() {
145 panic!("Output buffer too small!!! TODO: Make this as result, not as a panic LMAO");
146 }
147
148 let mut bit_size = 0 as usize;
150 let mut bit_stack = 0 as u64;
151
152 let mut out_ptr = 0;
153 for byte in input {
154 bit_stack <<= 8;
155 bit_stack |= *byte as u64;
156 bit_size += 8;
157
158 if bit_size == 24 {
159 output[out_ptr + 0] = Self::ENCODE_MAP[((bit_stack & 0b111111000000000000000000) >> 18) as usize] as u8;
160 output[out_ptr + 1] = Self::ENCODE_MAP[((bit_stack & 0b111111000000000000) >> 12) as usize] as u8;
161 output[out_ptr + 2] = Self::ENCODE_MAP[((bit_stack & 0b111111000000) >> 6) as usize] as u8;
162 output[out_ptr + 3] = Self::ENCODE_MAP[(bit_stack & 0b111111) as usize] as u8;
163
164 out_ptr += 4;
165 bit_size = 0;
166 }
167 }
168
169 let to_align = 6 - (bit_size % 6);
171 bit_stack <<= to_align;
172 bit_size += to_align;
173
174 let mut pad_len = 4;
175 while bit_size > 0 {
176 let shift = bit_size - 6;
177 output[out_ptr] = Self::ENCODE_MAP[((bit_stack & (0b111111 << shift)) >> shift) as usize] as u8;
178 bit_size -= 6;
179 pad_len -= 1;
180 out_ptr += 1;
181 }
182
183 #pad_token
184 }
185
186 pub fn decode_slice(input: &[u8], output: &mut [u8]) -> usize {
187 let mut out_ptr = 0;
188
189 let mut bit_stack = 0 as u64;
191 let mut bit_size = 0usize;
192
193 for &c in input {
194 if c == b'=' {
195 break;
196 }
197
198 let val = Self::DECODE_MAP[c as usize];
199 if val == 64 {
200 panic!("Wrong base64 character! ({:?})", c);
201 }
202
203 bit_stack <<= 6;
204 bit_stack |= val as u64;
205 bit_size += 6;
206
207 if bit_size >= 8 {
208 let shift = bit_size - 8;
209 let byte = ((bit_stack & (0b11111111 << shift)) >> shift) as u8;
210 bit_size -= 8;
211
212 output[out_ptr] = byte;
213 out_ptr += 1;
214 }
215 }
216
217 out_ptr
218 }
219
220 #use_std_tokens
221
222 pub const fn encode_len(n: usize) -> usize {
223 #encode_len_tokens
224 }
225
226 pub const fn decode_len(n: usize) -> usize {
227 #decode_len_tokens
228 }
229 }
230 }
231 .to_token_stream()
232 .into()
233}