Skip to main content

link_section_proc_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::iter::FromIterator;
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6
7mod xx3;
8
/// Attribute macro entry point for `#[in_section(...)]`.
///
/// Forwards the attribute arguments and the annotated item to `generate`
/// with the macro type `"in_section"` and the default crate name
/// `"link_section"`. The resulting expansion is an invocation of
/// `<path>::__support::in_section_parse!(...)`, where `<path>` is
/// `::link_section` unless the attribute arguments contain
/// `crate_path = <path>`.
#[allow(missing_docs)]
#[proc_macro_attribute]
pub fn in_section(attribute: TokenStream, item: TokenStream) -> TokenStream {
    generate("in_section", "link_section", attribute, item)
}
14
/// Attribute macro entry point for `#[section(...)]`.
///
/// Forwards the attribute arguments and the annotated item to `generate`
/// with the macro type `"section"` and the default crate name
/// `"link_section"`. The resulting expansion is an invocation of
/// `<path>::__support::section_parse!(...)`, where `<path>` is
/// `::link_section` unless the attribute arguments contain
/// `crate_path = <path>`.
#[allow(missing_docs)]
#[proc_macro_attribute]
pub fn section(attribute: TokenStream, item: TokenStream) -> TokenStream {
    generate("section", "link_section", attribute, item)
}
20
/// Decodes the source text of a plain string literal into the string it
/// denotes.
///
/// The literal's token text must have the form `"..."`; anything else (a
/// number, a raw string, a byte string, ...) is rejected. `name` is only used
/// to label panic messages.
///
/// Supported escapes are `\n`, `\r`, `\t`, `\\`, `\"`, `\'`, `\0`, `\xHH`
/// (exactly two hexadecimal digits) and `\u{...}`.
///
/// # Panics
///
/// Panics if the token is not a plain string literal or if it contains an
/// unsupported or malformed escape sequence.
fn decode_literal_string(name: &str, literal: Literal) -> String {
    let literal = literal.to_string();
    let Some(body) = literal
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    else {
        panic!("{}: Expected a literal string", name);
    };
    unescape_string_body(name, body)
}

/// Resolves the escape sequences in the text between a string literal's quotes.
///
/// A lone backslash at the very end of `body` is dropped silently.
fn unescape_string_body(name: &str, body: &str) -> String {
    // Fast path: nothing to unescape.
    if !body.contains('\\') {
        return body.to_string();
    }

    let mut output = String::with_capacity(body.len());
    let mut iter = body.chars();
    while let Some(c) = iter.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }
        match iter.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('\'') => output.push('\''),
            Some('0') => output.push('\0'),
            Some('x') => {
                // Both digits are hexadecimal. `to_digit` is used rather than
                // `from_str_radix` so that a sign such as `+f` is not accepted.
                let high = iter.next().and_then(|d| d.to_digit(16));
                let low = iter.next().and_then(|d| d.to_digit(16));
                let (Some(high), Some(low)) = (high, low) else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                // Two hex digits are at most 0xff, so the cast cannot truncate.
                output.push(char::from((high * 16 + low) as u8));
            }
            Some('u') => output.push(decode_unicode_escape(name, &mut iter)),
            Some(_) => panic!("{}: Expected a valid escape sequence", name),
            None => break,
        }
    }
    output
}

/// Decodes the `{...}` part of a `\u{...}` escape; `iter` is positioned just
/// after the `u`. Accepts one to six hex digits with optional `_` separators
/// after the first digit.
fn decode_unicode_escape(name: &str, iter: &mut impl Iterator<Item = char>) -> char {
    if iter.next() != Some('{') {
        panic!("{}: Expected a valid escape sequence", name);
    }
    let mut code: u32 = 0;
    let mut digits = 0;
    loop {
        match iter.next() {
            Some('}') if digits > 0 => break,
            Some('_') if digits > 0 => {}
            Some(c) => match c.to_digit(16) {
                // At most six digits, so `code` cannot overflow a `u32`.
                Some(value) if digits < 6 => {
                    code = code * 16 + value;
                    digits += 1;
                }
                _ => panic!("{}: Expected a valid escape sequence", name),
            },
            None => panic!("{}: Expected a valid escape sequence", name),
        }
    }
    match char::from_u32(code) {
        Some(c) => c,
        None => panic!("{}: Expected a valid escape sequence", name),
    }
}
66
67fn decode_literal_strings(name: &str, item: TokenTree) -> String {
68    let mut output = String::new();
69    match item {
70        TokenTree::Literal(literal) => {
71            output.push_str(&decode_literal_string(name, literal));
72        }
73        TokenTree::Group(group) => {
74            for token in group.stream().into_iter() {
75                output.push_str(&decode_literal_strings(name, token));
76            }
77        }
78        TokenTree::Punct(_) => {
79            // Ignore punctuation
80        }
81        _ => {
82            panic!("{}: Expected a literal string or group, got `{item}`", name);
83        }
84    }
85    output
86}
87
/// Extracts the single literal carried by `item`.
///
/// `item` may be the literal itself or an undelimited (`Delimiter::None`)
/// group that wraps exactly one token, which is then unwrapped recursively.
/// `name` labels panic messages.
///
/// # Panics
///
/// Panics if `item` is any other token, a delimited group, or a group that
/// does not hold exactly one token.
fn expect_literal(name: &str, item: TokenTree) -> Literal {
    let group = match item {
        TokenTree::Literal(literal) => return literal,
        TokenTree::Group(group) => group,
        token => panic!("{}: Expected a literal, got `{token}`", name),
    };

    let delimiter = group.delimiter();
    if delimiter != Delimiter::None {
        panic!(
            "{}: Expected a single literal, got `{:?}` group",
            name, delimiter
        );
    }

    let tokens: Vec<TokenTree> = group.stream().into_iter().collect();
    // The conversion succeeds only for exactly one token and otherwise hands
    // the vector back so its length can be reported.
    match <[TokenTree; 1]>::try_from(tokens) {
        Ok([only]) => expect_literal(name, only),
        Err(tokens) => panic!(
            "{}: Expected a single literal, got `{}`",
            name,
            tokens.len()
        ),
    }
}
114
115fn expect_numeric_literal(name: &str, item: TokenTree) -> usize {
116    let literal = expect_literal(name, item).to_string();
117    let Ok(literal) = literal.parse::<usize>() else {
118        panic!("{}: Expected a literal integer, got `{literal}`", name);
119    };
120    literal
121}
122
123/// Concatenate two identifiers.
124#[proc_macro]
125pub fn ident_concat(item: TokenStream) -> TokenStream {
126    let mut item = item.into_iter();
127    let Some(TokenTree::Group(pre_group)) = item.next() else {
128        panic!("pre_group: Expected a group");
129    };
130    let Some(TokenTree::Group(name_group)) = item.next() else {
131        panic!("name_group: Expected a group");
132    };
133    let Some(TokenTree::Group(post_group)) = item.next() else {
134        panic!("post_group: Expected a group");
135    };
136
137    let mut item = name_group.stream().into_iter();
138    let Some(TokenTree::Ident(ident)) = item.next() else {
139        panic!("ident: Expected an identifier");
140    };
141    let Some(TokenTree::Ident(ident2)) = item.next() else {
142        panic!("ident2: Expected an identifier");
143    };
144
145    let mut output = pre_group.stream();
146    output.extend([TokenTree::Ident(Ident::new(
147        &format!("{ident}{ident2}"),
148        Span::call_site(),
149    ))]);
150    output.extend(post_group.stream());
151    output
152}
153
154/// If the input string is longer than the max length, replace the tail end of
155/// the string with the hash of the string.
156///
157/// hash!(output input (prefix) hash_length max_length valid_section_chars)
158#[proc_macro]
159pub fn hash(item: TokenStream) -> TokenStream {
160    let mut item = item.into_iter();
161
162    let Some(TokenTree::Group(group)) = item.next() else {
163        panic!("output: Expected a group");
164    };
165    let group = group.stream();
166
167    let Some(TokenTree::Ident(literal)) = item.next() else {
168        panic!("input: Expected an identifier");
169    };
170    let literal = literal.to_string();
171
172    let Some(prefix_group) = item.next() else {
173        panic!("prefix: Expected a group");
174    };
175    let prefix = decode_literal_strings("prefix", prefix_group);
176
177    let Some(suffix_group) = item.next() else {
178        panic!("suffix: Expected a group");
179    };
180    let suffix = decode_literal_strings("suffix", suffix_group);
181
182    let hash_length = expect_numeric_literal(
183        "hash_length",
184        item.next().expect("hash_length: Missing argument"),
185    );
186    let max_length = expect_numeric_literal(
187        "max_length",
188        item.next().expect("max_length: Missing argument"),
189    );
190
191    let valid_section_chars = expect_literal(
192        "valid_section_chars",
193        item.next().expect("valid_section_chars: Missing argument"),
194    );
195    let valid_section_chars =
196        decode_literal_string("valid_section_chars", valid_section_chars).into_bytes();
197
198    // If the string is valid as-is, return it
199    let output = if literal.len() < max_length
200        && !literal
201            .to_string()
202            .contains(|c| c > '\u{007f}' || !valid_section_chars.contains(&(c as u8)))
203    {
204        format!("{prefix}{literal}{suffix}")
205    } else {
206        // Not valid, so we need to hash the string
207        let mut output = String::with_capacity(max_length + prefix.len() + suffix.len());
208        output.push_str(&prefix.to_string());
209        let mut next = literal.chars();
210        while output.len() < max_length - hash_length + prefix.len() {
211            let Some(c) = next.next() else {
212                break;
213            };
214            if c <= '\u{007f}' && valid_section_chars.contains(&(c as u8)) {
215                output.push(c);
216            }
217        }
218
219        let mut hash = xx3::xx3hash(&literal);
220        while output.len() < max_length + prefix.len() {
221            let c = valid_section_chars[hash as usize % valid_section_chars.len()];
222            output.push(c as char);
223            hash /= valid_section_chars.len() as u64;
224        }
225        output.push_str(&suffix);
226        output
227    };
228
229    fn emit(tree: TokenStream, output: &str, found: &mut bool) -> TokenStream {
230        if *found {
231            return tree;
232        }
233        let mut stream = TokenStream::new();
234        for input in tree.into_iter() {
235            match input {
236                _ if *found => stream.extend([input]),
237                TokenTree::Ident(ident) if ident.to_string() == "__" => {
238                    stream.extend([TokenTree::Literal(Literal::string(output))]);
239                    *found = true;
240                }
241                TokenTree::Group(group) => stream.extend([TokenTree::Group(Group::new(
242                    group.delimiter(),
243                    emit(group.stream(), output, found),
244                ))]),
245                _ => stream.extend([input]),
246            }
247        }
248        stream
249    }
250
251    let mut found = false;
252    let stream = emit(group, &output, &mut found);
253    if !found {
254        panic!("output: Expected to find __");
255    }
256    TokenStream::from_iter([TokenTree::Group(Group::new(Delimiter::None, stream))])
257}
258
/// Shared expansion for the section attribute macros.
///
/// Produces
/// `<path>::__support::<macro_type>_parse!(#[<macro_type>(<attribute>)] <item>);`
/// where `<path>` is the token sequence given as `crate_path = <path>` in
/// `attribute` (everything up to the next comma), or `::<macro_crate>` when no
/// such argument exists. The parenthesised argument list is omitted when
/// `attribute` is empty. `attribute` is forwarded whole, including any
/// `crate_path` argument.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(
    macro_type: &str,
    macro_crate: &str,
    attribute: TokenStream,
    item: TokenStream,
) -> TokenStream {
    let ident = |name: &str| TokenTree::Ident(Ident::new(name, Span::call_site()));
    let punct = |ch: char, spacing: Spacing| TokenTree::Punct(Punct::new(ch, spacing));
    let path_sep = || [punct(':', Spacing::Joint), punct(':', Spacing::Alone)];

    // Look for `crate_path = <tokens>` among the attribute arguments.
    let mut crate_path: Option<TokenStream> = None;
    let mut cursor = attribute.clone().into_iter();
    while let Some(tree) = cursor.next() {
        let is_key = matches!(&tree, TokenTree::Ident(id) if id.to_string() == "crate_path");
        if !is_key {
            continue;
        }
        // The token after the key is consumed whether or not it is `=`.
        let has_eq = matches!(cursor.next(), Some(TokenTree::Punct(p)) if p.as_char() == '=');
        if !has_eq {
            continue;
        }
        // The path runs up to the next comma (which is dropped) or the end.
        let mut path = TokenStream::new();
        for next in cursor.by_ref() {
            if matches!(&next, TokenTree::Punct(p) if p.as_char() == ',') {
                break;
            }
            path.extend(std::iter::once(next));
        }
        crate_path = Some(path);
        break;
    }

    // `#[<macro_type>]` or `#[<macro_type>(<attribute>)]`.
    let mut attr_body = vec![ident(macro_type)];
    if !attribute.is_empty() {
        attr_body.push(TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            attribute,
        )));
    }
    let mut inner = TokenStream::new();
    inner.extend([
        punct('#', Spacing::Alone),
        TokenTree::Group(Group::new(
            Delimiter::Bracket,
            attr_body.into_iter().collect(),
        )),
    ]);
    inner.extend(item);

    // Fall back to `::<macro_crate>` when no explicit path was supplied.
    let mut invoke = match crate_path {
        Some(path) => path,
        None => {
            let mut root = TokenStream::new();
            root.extend(path_sep());
            root.extend([ident(macro_crate)]);
            root
        }
    };

    invoke.extend(path_sep());
    invoke.extend([ident("__support")]);
    invoke.extend(path_sep());
    invoke.extend([
        ident(&format!("{macro_type}_parse")),
        punct('!', Spacing::Alone),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        punct(';', Spacing::Alone),
    ]);

    invoke
}