Skip to main content

link_section_proc_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::iter::FromIterator;
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6
7mod xx3;
8
9#[allow(missing_docs)]
10#[proc_macro_attribute]
11pub fn in_section(attribute: TokenStream, item: TokenStream) -> TokenStream {
12    generate("in_section", "link_section", attribute, item)
13}
14
15#[allow(missing_docs)]
16#[proc_macro_attribute]
17pub fn section(attribute: TokenStream, item: TokenStream) -> TokenStream {
18    generate("section", "link_section", attribute, item)
19}
20
/// Decodes a string literal token into the text it denotes.
///
/// `name` labels the macro argument in panic messages. The literal's source
/// text (as rendered by `Literal::to_string`) must be a plain `"..."` string
/// literal; the surrounding quotes are removed and escape sequences resolved.
///
/// # Panics
///
/// Panics if the literal is not enclosed in double quotes or contains an
/// escape sequence that cannot be decoded.
fn decode_literal_string(name: &str, literal: Literal) -> String {
    unescape_string_literal(name, &literal.to_string())
}

/// Strips the enclosing double quotes from `source` and resolves its escapes.
///
/// Supported escapes: `\n`, `\r`, `\t`, `\\`, `\"`, `\'`, `\0`, `\xNN` (two
/// hexadecimal digits), `\u{...}` (one to six hexadecimal digits, `_` allowed
/// as a separator) and a backslash immediately followed by a newline, which
/// drops the newline and the whitespace that follows it. A lone trailing
/// backslash is ignored.
///
/// Kept separate from `decode_literal_string` so the decoding logic operates
/// on plain `&str` and does not depend on a live proc-macro session.
fn unescape_string_literal(name: &str, source: &str) -> String {
    let Some(body) = source
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    else {
        panic!("{}: Expected a literal string", name);
    };

    // Fast path: nothing to unescape.
    if !body.contains('\\') {
        return body.to_string();
    }

    let mut output = String::with_capacity(body.len());
    let mut chars = body.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }
        match chars.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('\'') => output.push('\''),
            Some('0') => output.push('\0'),
            Some('x') => {
                // The two characters after `\x` are hexadecimal digits, so
                // they must be decoded in base 16 (not parsed as decimal).
                let high = chars.next().and_then(|digit| digit.to_digit(16));
                let low = chars.next().and_then(|digit| digit.to_digit(16));
                let (Some(high), Some(low)) = (high, low) else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                // Two hex digits are at most 0xFF, so the cast cannot truncate.
                output.push(char::from((high * 16 + low) as u8));
            }
            Some('u') => {
                if chars.next() != Some('{') {
                    panic!("{}: Expected a valid escape sequence", name);
                }
                let mut code: u32 = 0;
                let mut digits = 0;
                loop {
                    let digit = match chars.next() {
                        Some('}') => break,
                        Some('_') => continue,
                        Some(other) => other.to_digit(16),
                        None => None,
                    };
                    match digit {
                        // At most six digits, so `code` cannot overflow.
                        Some(value) if digits < 6 => {
                            code = code * 16 + value;
                            digits += 1;
                        }
                        _ => panic!("{}: Expected a valid escape sequence", name),
                    }
                }
                let decoded = if digits == 0 {
                    None
                } else {
                    char::from_u32(code)
                };
                let Some(decoded) = decoded else {
                    panic!("{}: Expected a valid escape sequence", name);
                };
                output.push(decoded);
            }
            Some('\n') => {
                // Line continuation: skip the newline and leading whitespace
                // of the next line.
                while matches!(chars.peek(), Some(' ' | '\t' | '\n' | '\r')) {
                    chars.next();
                }
            }
            Some(_) => panic!("{}: Expected a valid escape sequence", name),
            None => break,
        }
    }
    output
}
66
67fn decode_literal_strings(name: &str, item: TokenTree) -> String {
68    let mut output = String::new();
69    match item {
70        TokenTree::Literal(literal) => {
71            output.push_str(&decode_literal_string(name, literal));
72        }
73        TokenTree::Group(group) => {
74            for token in group.stream().into_iter() {
75                output.push_str(&decode_literal_strings(name, token));
76            }
77        }
78        TokenTree::Punct(_) => {
79            // Ignore punctuation
80        }
81        TokenTree::Ident(ident) => {
82            output.push_str(&ident.to_string());
83        }
84    }
85    output
86}
87
/// Extracts the single literal carried by `item`.
///
/// A bare literal is returned directly. A group is accepted only when it has
/// `Delimiter::None` and holds exactly one token, which is then unwrapped
/// recursively. `name` labels the argument in panic messages.
///
/// # Panics
///
/// Panics if `item` is an identifier or punctuation, a group with a visible
/// delimiter, or a group that does not contain exactly one token.
fn expect_literal(name: &str, item: TokenTree) -> Literal {
    let group = match item {
        TokenTree::Literal(literal) => return literal,
        TokenTree::Group(group) => group,
        token => {
            panic!("{}: Expected a literal, got `{token}`", name);
        }
    };

    let delimiter = group.delimiter();
    if delimiter != Delimiter::None {
        panic!(
            "{}: Expected a single literal, got `{:?}` group",
            name, delimiter
        );
    }

    let mut tokens: Vec<TokenTree> = group.stream().into_iter().collect();
    // Remember the size before popping so the panic reports the real count.
    let count = tokens.len();
    match (count, tokens.pop()) {
        (1, Some(only)) => expect_literal(name, only),
        _ => panic!("{}: Expected a single literal, got `{}`", name, count),
    }
}
114
115fn expect_numeric_literal(name: &str, item: TokenTree) -> usize {
116    let literal = expect_literal(name, item).to_string();
117    let Ok(literal) = literal.parse::<usize>() else {
118        panic!("{}: Expected a literal integer, got `{literal}`", name);
119    };
120    literal
121}
122
123/// Concatenate two identifiers.
124#[proc_macro]
125pub fn ident_concat(item: TokenStream) -> TokenStream {
126    let mut item = item.into_iter();
127    let Some(TokenTree::Group(pre_group)) = item.next() else {
128        panic!("pre_group: Expected a group");
129    };
130    let Some(TokenTree::Group(name_group)) = item.next() else {
131        panic!("name_group: Expected a group");
132    };
133    let Some(TokenTree::Group(post_group)) = item.next() else {
134        panic!("post_group: Expected a group");
135    };
136
137    let mut item = name_group.stream().into_iter();
138    let mut name = String::new();
139    while let Some(TokenTree::Ident(ident)) = item.next() {
140        name.push_str(&ident.to_string());
141    }
142
143    let mut output = pre_group.stream();
144    output.extend([TokenTree::Ident(Ident::new(
145        &name,
146        Span::call_site(),
147    ))]);
148    output.extend(post_group.stream());
149    output
150}
151
152/// If the input string is longer than the max length, replace the tail end of
153/// the string with the hash of the string.
154///
155/// hash!(output (prefix) (name) (suffix) hash_length max_length valid_section_chars)
156#[proc_macro]
157pub fn hash(item: TokenStream) -> TokenStream {
158    let mut item = item.into_iter();
159
160    let Some(TokenTree::Group(group)) = item.next() else {
161        panic!("output: Expected a group");
162    };
163    let group = group.stream();
164
165    let Some(prefix_group) = item.next() else {
166        panic!("prefix: Expected a group");
167    };
168    let prefix = decode_literal_strings("prefix", prefix_group);
169
170    let Some(input_group) = item.next() else {
171        panic!("input: Expected an identifier");
172    };
173    let literal = decode_literal_strings("input", input_group);
174
175    let Some(suffix_group) = item.next() else {
176        panic!("suffix: Expected a group");
177    };
178    let suffix = decode_literal_strings("suffix", suffix_group);
179
180    let hash_length = expect_numeric_literal(
181        "hash_length",
182        item.next().expect("hash_length: Missing argument"),
183    );
184    let max_length = expect_numeric_literal(
185        "max_length",
186        item.next().expect("max_length: Missing argument"),
187    );
188
189    let valid_section_chars = expect_literal(
190        "valid_section_chars",
191        item.next().expect("valid_section_chars: Missing argument"),
192    );
193    let valid_section_chars =
194        decode_literal_string("valid_section_chars", valid_section_chars).into_bytes();
195
196    // If the string is valid as-is, return it
197    let output = if literal.len() < max_length
198        && !literal
199            .to_string()
200            .contains(|c| c > '\u{007f}' || !valid_section_chars.contains(&(c as u8)))
201    {
202        format!("{prefix}{literal}{suffix}")
203    } else {
204        // Not valid, so we need to hash the string
205        let mut output = String::with_capacity(max_length + prefix.len() + suffix.len());
206        output.push_str(&prefix.to_string());
207        let mut next = literal.chars();
208        while output.len() < max_length - hash_length + prefix.len() {
209            let Some(c) = next.next() else {
210                break;
211            };
212            if c <= '\u{007f}' && valid_section_chars.contains(&(c as u8)) {
213                output.push(c);
214            }
215        }
216
217        let mut hash = xx3::xx3hash(&literal);
218        while output.len() < max_length + prefix.len() {
219            let c = valid_section_chars[hash as usize % valid_section_chars.len()];
220            output.push(c as char);
221            hash /= valid_section_chars.len() as u64;
222        }
223        output.push_str(&suffix);
224        output
225    };
226
227    fn emit(tree: TokenStream, output: &str, found: &mut bool) -> TokenStream {
228        if *found {
229            return tree;
230        }
231        let mut stream = TokenStream::new();
232        for input in tree.into_iter() {
233            match input {
234                _ if *found => stream.extend([input]),
235                TokenTree::Ident(ident) if ident.to_string() == "__" => {
236                    stream.extend([TokenTree::Literal(Literal::string(output))]);
237                    *found = true;
238                }
239                TokenTree::Group(group) => stream.extend([TokenTree::Group(Group::new(
240                    group.delimiter(),
241                    emit(group.stream(), output, found),
242                ))]),
243                _ => stream.extend([input]),
244            }
245        }
246        stream
247    }
248
249    let mut found = false;
250    let stream = emit(group, &output, &mut found);
251    if !found {
252        panic!("output: Expected to find __");
253    }
254    TokenStream::from_iter([TokenTree::Group(Group::new(Delimiter::None, stream))])
255}
256
/// Shared expansion for the attribute macros in this crate.
///
/// Produces the token sequence
///
/// `<path>::__support::<macro_type>_parse!(#[<macro_type>(<attribute>)] <item>);`
///
/// where `<path>` is the token sequence following a top-level
/// `crate_path =` in `attribute` (up to the next comma or the end), or
/// `::<macro_crate>` when no such entry exists. When `attribute` is empty the
/// re-emitted attribute is just `#[<macro_type>]`. The attribute tokens are
/// forwarded unchanged, including any `crate_path` entry.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(
    macro_type: &str,
    macro_crate: &str,
    attribute: TokenStream,
    item: TokenStream,
) -> TokenStream {
    /// Scans the top-level attribute tokens for `crate_path = <tokens>` and
    /// returns `<tokens>` (everything up to the next comma or the end).
    fn find_crate_path(attribute: TokenStream) -> Option<TokenStream> {
        let mut tokens = attribute.into_iter();
        while let Some(token) = tokens.next() {
            let is_key =
                matches!(&token, TokenTree::Ident(ident) if ident.to_string() == "crate_path");
            if !is_key {
                continue;
            }
            // The token after `crate_path` is consumed whether or not it is `=`.
            let followed_by_eq =
                matches!(tokens.next(), Some(TokenTree::Punct(punct)) if punct.as_char() == '=');
            if !followed_by_eq {
                continue;
            }
            let path: TokenStream = tokens
                .by_ref()
                .take_while(|next| !matches!(next, TokenTree::Punct(p) if p.as_char() == ','))
                .collect();
            return Some(path);
        }
        None
    }

    /// The two punctuation tokens that make up `::`.
    fn double_colon() -> [TokenTree; 2] {
        [
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
        ]
    }

    /// An identifier token with a call-site span.
    fn call_site_ident(name: &str) -> TokenTree {
        TokenTree::Ident(Ident::new(name, Span::call_site()))
    }

    // Search for crate_path in attributes
    let crate_path = find_crate_path(attribute.clone());

    // Rebuild the attribute: `#[macro_type]` or `#[macro_type(attribute)]`.
    let mut attribute_body = vec![call_site_ident(macro_type)];
    if !attribute.is_empty() {
        attribute_body.push(TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            attribute,
        )));
    }

    let mut inner = TokenStream::new();
    inner.extend([
        TokenTree::Punct(Punct::new('#', Spacing::Alone)),
        TokenTree::Group(Group::new(
            Delimiter::Bracket,
            attribute_body.into_iter().collect(),
        )),
    ]);
    inner.extend(item);

    // Path prefix: the user-supplied crate path, or `::macro_crate`.
    let mut invoke = match crate_path {
        Some(path) => path,
        None => {
            let mut default_path = TokenStream::new();
            default_path.extend(double_colon());
            default_path.extend(std::iter::once(call_site_ident(macro_crate)));
            default_path
        }
    };

    // `::__support::<macro_type>_parse!( ... );`
    invoke.extend(double_colon());
    invoke.extend(std::iter::once(call_site_ident("__support")));
    invoke.extend(double_colon());
    invoke.extend([
        call_site_ident(&format!("{macro_type}_parse")),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);

    invoke
}