Skip to main content

link_section_proc_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::iter::FromIterator;
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6
7mod xx3;
8
9#[allow(missing_docs)]
10#[proc_macro_attribute]
11pub fn in_section(attribute: TokenStream, item: TokenStream) -> TokenStream {
12    generate("in_section", "link_section", attribute, item)
13}
14
15#[allow(missing_docs)]
16#[proc_macro_attribute]
17pub fn section(attribute: TokenStream, item: TokenStream) -> TokenStream {
18    generate("section", "link_section", attribute, item)
19}
20
/// Decode the source text of a string literal token into the string it denotes.
///
/// The literal is rendered with `to_string()`, the surrounding double quotes
/// are removed, and the escapes `\n`, `\r`, `\t`, `\\`, `\"`, `\'`, `\0` and
/// `\xHH` (two hexadecimal digits) are resolved. `name` is only used to label
/// panic messages.
///
/// # Panics
///
/// Panics if the rendered literal is not wrapped in `"` characters, if `\x` is
/// not followed by two hexadecimal digits, or if any other escape sequence is
/// encountered.
fn decode_literal_string(name: &str, literal: Literal) -> String {
    let literal = literal.to_string();
    let Some(literal) = literal
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    else {
        panic!("{}: Expected a literal string", name);
    };

    // Fast path: nothing to unescape.
    if !literal.contains('\\') {
        return literal.to_string();
    }

    let mut output = String::with_capacity(literal.len());
    let mut iter = literal.chars();
    while let Some(c) = iter.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }
        match iter.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('\'') => output.push('\''),
            Some('0') => output.push('\0'),
            Some('x') => {
                let (Some(high), Some(low)) = (iter.next(), iter.next()) else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                // The two characters are base-16 digits; parsing them as a
                // decimal number would turn `\x41` into byte 41 and reject
                // any escape containing `a`-`f`.
                let byte = high
                    .to_digit(16)
                    .zip(low.to_digit(16))
                    .and_then(|(high, low)| u8::try_from(high * 16 + low).ok());
                let Some(byte) = byte else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                output.push(char::from(byte));
            }
            Some(_) => panic!("{}: Expected a valid escape sequence", name),
            // A trailing lone backslash ends decoding without emitting anything.
            None => break,
        }
    }
    output
}
66
67fn decode_literal_strings(name: &str, item: TokenTree) -> String {
68    let mut output = String::new();
69    match item {
70        TokenTree::Literal(literal) => {
71            output.push_str(&decode_literal_string(name, literal));
72        }
73        TokenTree::Group(group) => {
74            for token in group.stream().into_iter() {
75                output.push_str(&decode_literal_strings(name, token));
76            }
77        }
78        TokenTree::Punct(_) => {
79            // Ignore punctuation
80        }
81        TokenTree::Ident(ident) => {
82            output.push_str(&ident.to_string());
83        }
84    }
85    output
86}
87
/// Extract the single literal carried by `item`.
///
/// A bare literal is returned directly. A group with `Delimiter::None` that
/// holds exactly one token is unwrapped and examined again, so arbitrarily
/// nested invisible groups are accepted. `name` labels panic messages.
///
/// # Panics
///
/// Panics if `item` is a group with a visible delimiter, a group that does not
/// contain exactly one token, or any non-literal, non-group token.
fn expect_literal(name: &str, item: TokenTree) -> Literal {
    let mut current = item;
    loop {
        let group = match current {
            TokenTree::Literal(literal) => return literal,
            TokenTree::Group(group) => group,
            token => panic!("{}: Expected a literal, got `{token}`", name),
        };

        let delimiter = group.delimiter();
        assert!(
            delimiter == Delimiter::None,
            "{}: Expected a single literal, got `{:?}` group",
            name,
            delimiter
        );

        let mut tokens: Vec<TokenTree> = group.stream().into_iter().collect();
        assert!(
            tokens.len() == 1,
            "{}: Expected a single literal, got `{}`",
            name,
            tokens.len()
        );

        // Exactly one token remains; peel the group and look again.
        current = tokens.pop().unwrap();
    }
}
114
115fn expect_numeric_literal(name: &str, item: TokenTree) -> usize {
116    let literal = expect_literal(name, item).to_string();
117    let Ok(literal) = literal.parse::<usize>() else {
118        panic!("{}: Expected a literal integer, got `{literal}`", name);
119    };
120    literal
121}
122
123/// Concatenate two identifiers.
124#[proc_macro]
125pub fn ident_concat(item: TokenStream) -> TokenStream {
126    let mut item = item.into_iter();
127    let Some(TokenTree::Group(pre_group)) = item.next() else {
128        panic!("pre_group: Expected a group");
129    };
130    let Some(TokenTree::Group(name_group)) = item.next() else {
131        panic!("name_group: Expected a group");
132    };
133    let Some(TokenTree::Group(post_group)) = item.next() else {
134        panic!("post_group: Expected a group");
135    };
136
137    let mut item = name_group.stream().into_iter();
138    let mut name = String::new();
139    while let Some(TokenTree::Ident(ident)) = item.next() {
140        name.push_str(&ident.to_string());
141    }
142
143    let mut output = pre_group.stream();
144    output.extend([TokenTree::Ident(Ident::new(&name, Span::call_site()))]);
145    output.extend(post_group.stream());
146    output
147}
148
149/// If the input string is longer than the max length, replace the tail end of
150/// the string with the hash of the string.
151///
152/// hash!(output (prefix) (name) (suffix) hash_length max_length valid_section_chars)
153#[proc_macro]
154pub fn hash(item: TokenStream) -> TokenStream {
155    let mut item = item.into_iter();
156
157    let Some(TokenTree::Group(group)) = item.next() else {
158        panic!("output: Expected a group");
159    };
160    let group = group.stream();
161
162    let Some(prefix_group) = item.next() else {
163        panic!("prefix: Expected a group");
164    };
165    let prefix = decode_literal_strings("prefix", prefix_group);
166
167    let Some(input_group) = item.next() else {
168        panic!("input: Expected an identifier");
169    };
170    let literal = decode_literal_strings("input", input_group);
171
172    let Some(suffix_group) = item.next() else {
173        panic!("suffix: Expected a group");
174    };
175    let suffix = decode_literal_strings("suffix", suffix_group);
176
177    let hash_length = expect_numeric_literal(
178        "hash_length",
179        item.next().expect("hash_length: Missing argument"),
180    );
181    let max_length = expect_numeric_literal(
182        "max_length",
183        item.next().expect("max_length: Missing argument"),
184    );
185
186    let valid_section_chars = expect_literal(
187        "valid_section_chars",
188        item.next().expect("valid_section_chars: Missing argument"),
189    );
190    let valid_section_chars =
191        decode_literal_string("valid_section_chars", valid_section_chars).into_bytes();
192
193    // If the string is valid as-is, return it
194    let output = if literal.len() < max_length
195        && !literal
196            .to_string()
197            .contains(|c| c > '\u{007f}' || !valid_section_chars.contains(&(c as u8)))
198    {
199        format!("{prefix}{literal}{suffix}")
200    } else {
201        // Not valid, so we need to hash the string
202        let mut output = String::with_capacity(max_length + prefix.len() + suffix.len());
203        output.push_str(&prefix.to_string());
204        let mut next = literal.chars();
205        while output.len() < max_length - hash_length + prefix.len() {
206            let Some(c) = next.next() else {
207                break;
208            };
209            if c <= '\u{007f}' && valid_section_chars.contains(&(c as u8)) {
210                output.push(c);
211            }
212        }
213
214        let mut hash = xx3::xx3hash(&literal);
215        while output.len() < max_length + prefix.len() {
216            let c = valid_section_chars[hash as usize % valid_section_chars.len()];
217            output.push(c as char);
218            hash /= valid_section_chars.len() as u64;
219        }
220        output.push_str(&suffix);
221        output
222    };
223
224    fn emit(tree: TokenStream, output: &str, found: &mut bool) -> TokenStream {
225        if *found {
226            return tree;
227        }
228        let mut stream = TokenStream::new();
229        for input in tree.into_iter() {
230            match input {
231                _ if *found => stream.extend([input]),
232                TokenTree::Ident(ident) if ident.to_string() == "__" => {
233                    stream.extend([TokenTree::Literal(Literal::string(output))]);
234                    *found = true;
235                }
236                TokenTree::Group(group) => stream.extend([TokenTree::Group(Group::new(
237                    group.delimiter(),
238                    emit(group.stream(), output, found),
239                ))]),
240                _ => stream.extend([input]),
241            }
242        }
243        stream
244    }
245
246    let mut found = false;
247    let stream = emit(group, &output, &mut found);
248    if !found {
249        panic!("output: Expected to find __");
250    }
251    TokenStream::from_iter([TokenTree::Group(Group::new(Delimiter::None, stream))])
252}
253
/// Build the expansion shared by the attribute macros in this file.
///
/// The result is a macro invocation of the form
///
/// ```text
/// <path>::__support::<macro_type>_parse!(#[<macro_type>(<attribute>)] <item>);
/// ```
///
/// `<path>` is the run of tokens that follows `crate_path =` in `attribute`
/// (up to the next comma or the end of the attribute), or `::<macro_crate>`
/// when no such entry exists. The attribute tokens are forwarded unchanged,
/// and the parenthesised argument list is left out when `attribute` is empty.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(
    macro_type: &str,
    macro_crate: &str,
    attribute: TokenStream,
    item: TokenStream,
) -> TokenStream {
    // Small token builders used throughout.
    let make_ident = |text: &str| TokenTree::Ident(Ident::new(text, Span::call_site()));
    let make_punct = |ch: char| TokenTree::Punct(Punct::new(ch, Spacing::Alone));
    let path_sep = || {
        [
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
        ]
    };

    // Search for crate_path in attributes
    let mut crate_path = None;
    let mut tokens = attribute.clone().into_iter();
    while let Some(token) = tokens.next() {
        let is_key = matches!(&token, TokenTree::Ident(key) if key.to_string() == "crate_path");
        if !is_key {
            continue;
        }
        // The key must be followed by `=`; whatever follows is consumed either way.
        match tokens.next() {
            Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => {}
            _ => continue,
        }
        // Collect tokens until comma or end
        let mut path = TokenStream::new();
        for next in tokens.by_ref() {
            if matches!(&next, TokenTree::Punct(p) if p.as_char() == ',') {
                break;
            }
            path.extend([next]);
        }
        crate_path = Some(path);
        break;
    }

    // #[macro_type] or #[macro_type(attribute)]
    let mut attr_body = TokenStream::from_iter([make_ident(macro_type)]);
    if !attribute.is_empty() {
        attr_body.extend([TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            attribute,
        ))]);
    }
    let mut inner = TokenStream::from_iter([
        make_punct('#'),
        TokenTree::Group(Group::new(Delimiter::Bracket, attr_body)),
    ]);
    inner.extend(item);

    // Explicit crate path if given, otherwise `::macro_crate`.
    let mut invoke = crate_path.unwrap_or_else(|| {
        let mut root = TokenStream::from_iter(path_sep());
        root.extend([make_ident(macro_crate)]);
        root
    });

    // ::__support::<macro_type>_parse!( ... );
    invoke.extend(path_sep());
    invoke.extend([make_ident("__support")]);
    invoke.extend(path_sep());
    invoke.extend([
        make_ident(&format!("{macro_type}_parse")),
        make_punct('!'),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        make_punct(';'),
    ]);

    invoke
}