Skip to main content

link_section_proc_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::iter::FromIterator;
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6
7mod xx3;
8
9#[allow(missing_docs)]
10#[proc_macro_attribute]
11pub fn in_section(attribute: TokenStream, item: TokenStream) -> TokenStream {
12    generate("in_section", "link_section", attribute, item)
13}
14
15#[allow(missing_docs)]
16#[proc_macro_attribute]
17pub fn section(attribute: TokenStream, item: TokenStream) -> TokenStream {
18    generate("section", "link_section", attribute, item)
19}
20
/// Decode the value of a (non-raw) string literal token.
///
/// `name` identifies the macro argument being decoded and prefixes every
/// panic message.
///
/// The literal's source text (as rendered by `Literal::to_string`) must be
/// wrapped in double quotes. Supported escapes:
/// - `\n`, `\r`, `\t`, `\\`, `\"`, `\'`, `\0`;
/// - two-digit hexadecimal `\xHH`;
/// - Unicode `\u{...}` (1-6 hex digits, `_` separators allowed);
/// - line continuations: a backslash before a newline, which also drops the
///   next line's leading spaces, tabs and newlines.
///
/// # Panics
///
/// Panics if the token is not a double-quoted string literal, or if it
/// contains a malformed or unsupported escape sequence.
fn decode_literal_string(name: &str, literal: Literal) -> String {
    let literal = literal.to_string();
    let Some(literal) = literal.strip_prefix('"') else {
        panic!("{}: Expected a literal string", name);
    };
    let Some(literal) = literal.strip_suffix('"') else {
        panic!("{}: Expected a literal string", name);
    };
    // Fast path: nothing to unescape.
    if !literal.contains('\\') {
        return literal.to_string();
    }

    let mut output = String::with_capacity(literal.len());
    let mut iter = literal.chars().peekable();
    while let Some(c) = iter.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }
        match iter.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('\'') => output.push('\''),
            Some('0') => output.push('\0'),
            Some('x') => {
                // `\xHH` is hexadecimal, so each digit must be read in base 16.
                let high = iter.next().and_then(|digit| digit.to_digit(16));
                let low = iter.next().and_then(|digit| digit.to_digit(16));
                let (Some(high), Some(low)) = (high, low) else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                // Both digits are < 16, so the combined value always fits in a u8.
                output.push(char::from(((high << 4) | low) as u8));
            }
            Some('u') => {
                if iter.next() != Some('{') {
                    panic!("{}: Expected a unicode escape", name);
                }
                let mut value: u32 = 0;
                let mut digits = 0;
                loop {
                    let next = iter.next();
                    match next {
                        Some('}') if digits > 0 => break,
                        // Rust allows `_` separators after the first hex digit.
                        Some('_') if digits > 0 => continue,
                        _ => {}
                    }
                    let Some(digit) = next
                        .filter(|_| digits < 6)
                        .and_then(|digit| digit.to_digit(16))
                    else {
                        panic!("{}: Expected a unicode escape", name);
                    };
                    value = value * 16 + digit;
                    digits += 1;
                }
                let Some(decoded) = char::from_u32(value) else {
                    panic!("{}: Expected a unicode escape", name);
                };
                output.push(decoded);
            }
            Some('\n') => {
                // Line continuation: drop the newline and the following
                // whitespace, as the Rust lexer does.
                while iter
                    .next_if(|c| matches!(c, ' ' | '\t' | '\n' | '\r'))
                    .is_some()
                {}
            }
            Some(_) => panic!("{}: Expected a valid escape sequence", name),
            None => break,
        }
    }
    output
}
66
67fn decode_literal_strings(name: &str, item: TokenTree) -> String {
68    let mut output = String::new();
69    match item {
70        TokenTree::Literal(literal) => {
71            output.push_str(&decode_literal_string(name, literal));
72        }
73        TokenTree::Group(group) => {
74            for token in group.stream().into_iter() {
75                output.push_str(&decode_literal_strings(name, token));
76            }
77        }
78        _ => {
79            panic!("{}: Expected a literal string or group", name);
80        }
81    }
82    output
83}
84
85/// Concatenate two identifiers.
86#[proc_macro]
87pub fn ident_concat(item: TokenStream) -> TokenStream {
88    let mut item = item.into_iter();
89    let Some(TokenTree::Group(pre_group)) = item.next() else {
90        panic!("pre_group: Expected a group");
91    };
92    let Some(TokenTree::Group(name_group)) = item.next() else {
93        panic!("name_group: Expected a group");
94    };
95    let Some(TokenTree::Group(post_group)) = item.next() else {
96        panic!("post_group: Expected a group");
97    };
98
99    let mut item = name_group.stream().into_iter();
100    let Some(TokenTree::Ident(ident)) = item.next() else {
101        panic!("ident: Expected an identifier");
102    };
103    let Some(TokenTree::Ident(ident2)) = item.next() else {
104        panic!("ident2: Expected an identifier");
105    };
106
107    let mut output = pre_group.stream();
108    output.extend([TokenTree::Ident(Ident::new(
109        &format!("{ident}{ident2}"),
110        Span::call_site(),
111    ))]);
112    output.extend(post_group.stream());
113    output
114}
115
116/// If the input string is longer than the max length, replace the tail end of
117/// the string with the hash of the string.
118///
119/// hash!(output input (prefix) hash_length max_length valid_section_chars)
120#[proc_macro]
121pub fn hash(item: TokenStream) -> TokenStream {
122    let mut item = item.into_iter();
123
124    let Some(TokenTree::Group(group)) = item.next() else {
125        panic!("output: Expected a group");
126    };
127    let group = group.stream();
128
129    let Some(TokenTree::Ident(literal)) = item.next() else {
130        panic!("input: Expected an identifier");
131    };
132    let literal = literal.to_string();
133
134    let Some(prefix_group) = item.next() else {
135        panic!("prefix: Expected a group");
136    };
137    let prefix = decode_literal_strings("prefix", prefix_group);
138
139    let Some(suffix_group) = item.next() else {
140        panic!("suffix: Expected a group");
141    };
142    let suffix = decode_literal_strings("suffix", suffix_group);
143
144    let Some(TokenTree::Literal(hash_length)) = item.next() else {
145        panic!("hash_length: Expected a literal integer");
146    };
147    let Ok(hash_length) = hash_length.to_string().parse::<usize>() else {
148        panic!("hash_length: Expected a literal integer");
149    };
150
151    let Some(TokenTree::Literal(max_length)) = item.next() else {
152        panic!("max_length: Expected a literal integer");
153    };
154    let Ok(max_length) = max_length.to_string().parse::<usize>() else {
155        panic!("max_length: Expected a literal integer");
156    };
157
158    // Valid section chars: "..."
159    let Some(TokenTree::Literal(valid_section_chars)) = item.next() else {
160        panic!("valid_section_chars: Expected a literal string");
161    };
162    let valid_section_chars =
163        decode_literal_string("valid_section_chars", valid_section_chars).into_bytes();
164
165    // If the string is valid as-is, return it
166    let output = if literal.len() < max_length
167        && !literal
168            .to_string()
169            .contains(|c| c > '\u{007f}' || !valid_section_chars.contains(&(c as u8)))
170    {
171        format!("{prefix}{literal}{suffix}")
172    } else {
173        // Not valid, so we need to hash the string
174        let mut output = String::with_capacity(max_length + prefix.len() + suffix.len());
175        output.push_str(&prefix.to_string());
176        let mut next = literal.chars();
177        while output.len() < max_length - hash_length + prefix.len() {
178            let Some(c) = next.next() else {
179                break;
180            };
181            if c <= '\u{007f}' && valid_section_chars.contains(&(c as u8)) {
182                output.push(c);
183            }
184        }
185
186        let mut hash = xx3::xx3hash(&literal);
187        while output.len() < max_length + prefix.len() {
188            let c = valid_section_chars[hash as usize % valid_section_chars.len()];
189            output.push(c as char);
190            hash /= valid_section_chars.len() as u64;
191        }
192        output.push_str(&suffix);
193        output
194    };
195
196    fn emit(tree: TokenStream, output: &str, found: &mut bool) -> TokenStream {
197        if *found {
198            return tree;
199        }
200        let mut stream = TokenStream::new();
201        for input in tree.into_iter() {
202            match input {
203                _ if *found => stream.extend([input]),
204                TokenTree::Ident(ident) if ident.to_string() == "__" => {
205                    stream.extend([TokenTree::Literal(Literal::string(output))]);
206                    *found = true;
207                }
208                TokenTree::Group(group) => stream.extend([TokenTree::Group(Group::new(
209                    group.delimiter(),
210                    emit(group.stream(), output, found),
211                ))]),
212                _ => stream.extend([input]),
213            }
214        }
215        stream
216    }
217
218    let mut found = false;
219    let stream = emit(group, &output, &mut found);
220    if !found {
221        panic!("output: Expected to find __");
222    }
223    TokenStream::from_iter([TokenTree::Group(Group::new(Delimiter::None, stream))])
224}
225
/// Shared expansion for the `in_section` and `section` attribute macros.
///
/// Produces:
///
/// - `<path>::__support::<macro_type>_parse!(#[<macro_type>(<attribute>)] <item>);`
/// - or, when `attribute` is empty, the same with a bare `#[<macro_type>]`.
///
/// `<path>` is the token sequence after the first top-level `crate_path =` in
/// `attribute`, up to the next top-level comma. Without such an argument it is
/// `::<macro_crate>`. The attribute tokens are forwarded unchanged, including
/// any `crate_path` argument.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(
    macro_type: &str,
    macro_crate: &str,
    attribute: TokenStream,
    item: TokenStream,
) -> TokenStream {
    /// Return the tokens following the first top-level `crate_path =`, up to
    /// (but not including) the next top-level comma, if present.
    #[allow(unknown_lints, tail_expr_drop_order)]
    fn find_crate_path(attribute: &TokenStream) -> Option<TokenStream> {
        let mut tokens = attribute.clone().into_iter();
        while let Some(token) = tokens.next() {
            let is_key =
                matches!(&token, TokenTree::Ident(ident) if ident.to_string() == "crate_path");
            // The token after the key is consumed even when it is not `=`.
            if is_key
                && matches!(tokens.next(), Some(TokenTree::Punct(punct)) if punct.as_char() == '=')
            {
                return Some(
                    tokens
                        .by_ref()
                        .take_while(|t| !matches!(t, TokenTree::Punct(p) if p.as_char() == ','))
                        .collect(),
                );
            }
        }
        None
    }

    let crate_path = find_crate_path(&attribute);

    // `#[macro_type]` or `#[macro_type(attribute)]`, followed by the item.
    let mut meta =
        TokenStream::from_iter([TokenTree::Ident(Ident::new(macro_type, Span::call_site()))]);
    if !attribute.is_empty() {
        meta.extend([TokenTree::Group(Group::new(Delimiter::Parenthesis, attribute))]);
    }
    let mut inner = TokenStream::from_iter([
        TokenTree::Punct(Punct::new('#', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Bracket, meta)),
    ]);
    inner.extend(item);

    // A `::` path separator.
    let path_sep = || {
        [
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
        ]
    };

    let mut invoke = crate_path.unwrap_or_else(|| {
        let mut root = TokenStream::from_iter(path_sep());
        root.extend([TokenTree::Ident(Ident::new(macro_crate, Span::call_site()))]);
        root
    });
    invoke.extend(path_sep());
    invoke.extend([TokenTree::Ident(Ident::new("__support", Span::call_site()))]);
    invoke.extend(path_sep());
    invoke.extend([
        TokenTree::Ident(Ident::new(&format!("{macro_type}_parse"), Span::call_site())),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);
    invoke
}