Skip to main content

link_section_proc_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::iter::FromIterator;
4
5use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
6
7mod xx3;
8
9#[allow(missing_docs)]
10#[proc_macro_attribute]
11pub fn in_section(attribute: TokenStream, item: TokenStream) -> TokenStream {
12    generate("in_section", "link_section", attribute, item)
13}
14
15#[allow(missing_docs)]
16#[proc_macro_attribute]
17pub fn section(attribute: TokenStream, item: TokenStream) -> TokenStream {
18    generate("section", "link_section", attribute, item)
19}
20
/// Decodes a string-literal token into the string value it denotes.
///
/// `name` is only used to label panic messages so the macro user can tell
/// which argument was malformed.
///
/// # Panics
///
/// Panics if the token's source text is not wrapped in double quotes, or if
/// it contains an escape sequence other than `\n`, `\r`, `\t`, `\\`, `\"`,
/// `\'`, `\0` or `\xHH`.
fn decode_literal_string(name: &str, literal: Literal) -> String {
    decode_literal_source(name, &literal.to_string())
}

/// Decodes the source text of a double-quoted string literal (quotes included).
///
/// Kept separate from [`decode_literal_string`] so the escape handling works
/// on plain `&str` input.
fn decode_literal_source(name: &str, source: &str) -> String {
    let Some(body) = source
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    else {
        panic!("{}: Expected a literal string", name);
    };

    // Fast path: nothing to unescape.
    if !body.contains('\\') {
        return body.to_string();
    }

    let mut output = String::with_capacity(body.len());
    let mut chars = body.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }
        match chars.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('\'') => output.push('\''),
            Some('0') => output.push('\0'),
            Some('x') => {
                // `\xHH`: both digits are hexadecimal, not decimal.
                let high = chars.next().and_then(|digit| digit.to_digit(16));
                let low = chars.next().and_then(|digit| digit.to_digit(16));
                let (Some(high), Some(low)) = (high, low) else {
                    panic!("{}: Expected a hexadecimal character", name);
                };
                // Both digits are < 16, so the value always fits in a byte.
                output.push(char::from((high * 16 + low) as u8));
            }
            Some(_) => panic!("{}: Expected a valid escape sequence", name),
            // A trailing lone backslash is dropped.
            None => break,
        }
    }
    output
}
66
67fn decode_literal_strings(name: &str, item: TokenTree) -> String {
68    let mut output = String::new();
69    match item {
70        TokenTree::Literal(literal) => {
71            output.push_str(&decode_literal_string(name, literal));
72        }
73        TokenTree::Group(group) => {
74            for token in group.stream().into_iter() {
75                output.push_str(&decode_literal_strings(name, token));
76            }
77        }
78        _ => {
79            panic!("{}: Expected a literal string or group", name);
80        }
81    }
82    output
83}
84
85/// If the input string is longer than the max length, replace the tail end of
86/// the string with the hash of the string.
87///
88/// hash!(output input (prefix) hash_length max_length valid_section_chars)
89#[allow(missing_docs)]
90#[proc_macro]
91pub fn hash(item: TokenStream) -> TokenStream {
92    let mut item = item.into_iter();
93
94    let Some(TokenTree::Group(group)) = item.next() else {
95        panic!("output: Expected a group");
96    };
97    let group = group.stream();
98
99    let Some(TokenTree::Ident(literal)) = item.next() else {
100        panic!("input: Expected an identifier");
101    };
102    let literal = literal.to_string();
103
104    let Some(prefix_group) = item.next() else {
105        panic!("prefix: Expected a group");
106    };
107    let prefix = decode_literal_strings("prefix", prefix_group);
108
109    let Some(suffix_group) = item.next() else {
110        panic!("suffix: Expected a group");
111    };
112    let suffix = decode_literal_strings("suffix", suffix_group);
113
114    let Some(TokenTree::Literal(hash_length)) = item.next() else {
115        panic!("hash_length: Expected a literal integer");
116    };
117    let Ok(hash_length) = hash_length.to_string().parse::<usize>() else {
118        panic!("hash_length: Expected a literal integer");
119    };
120
121    let Some(TokenTree::Literal(max_length)) = item.next() else {
122        panic!("max_length: Expected a literal integer");
123    };
124    let Ok(max_length) = max_length.to_string().parse::<usize>() else {
125        panic!("max_length: Expected a literal integer");
126    };
127
128    // Valid section chars: "..."
129    let Some(TokenTree::Literal(valid_section_chars)) = item.next() else {
130        panic!("valid_section_chars: Expected a literal string");
131    };
132    let valid_section_chars =
133        decode_literal_string("valid_section_chars", valid_section_chars).into_bytes();
134
135    // If the string is valid as-is, return it
136    let output = if literal.len() < max_length
137        && !literal
138            .to_string()
139            .contains(|c| c > '\u{007f}' || !valid_section_chars.contains(&(c as u8)))
140    {
141        format!("{prefix}{literal}{suffix}")
142    } else {
143        // Not valid, so we need to hash the string
144        let mut output = String::with_capacity(max_length + prefix.len() + suffix.len());
145        output.push_str(&prefix.to_string());
146        let mut next = literal.chars();
147        while output.len() < max_length - hash_length + prefix.len() {
148            let Some(c) = next.next() else {
149                break;
150            };
151            if c <= '\u{007f}' && valid_section_chars.contains(&(c as u8)) {
152                output.push(c);
153            }
154        }
155
156        let mut hash = xx3::xx3hash(&literal);
157        while output.len() < max_length + prefix.len() {
158            let c = valid_section_chars[hash as usize % valid_section_chars.len()];
159            output.push(c as char);
160            hash = hash / valid_section_chars.len() as u64;
161        }
162        output.push_str(&suffix);
163        output
164    };
165
166    fn emit(tree: TokenStream, output: &str, found: &mut bool) -> TokenStream {
167        if *found {
168            return tree;
169        }
170        let mut stream = TokenStream::new();
171        for input in tree.into_iter() {
172            match input {
173                _ if *found => stream.extend([input]),
174                TokenTree::Ident(ident) if ident.to_string() == "__" => {
175                    stream.extend([TokenTree::Literal(Literal::string(&output))]);
176                    *found = true;
177                }
178                TokenTree::Group(group) => stream.extend([TokenTree::Group(Group::new(
179                    group.delimiter(),
180                    emit(group.stream(), output, found),
181                ))]),
182                _ => stream.extend([input]),
183            }
184        }
185        stream
186    }
187
188    let mut found = false;
189    let stream = emit(group, &output, &mut found);
190    if !found {
191        panic!("output: Expected to find __");
192    }
193    TokenStream::from_iter([TokenTree::Group(Group::new(Delimiter::None, stream))])
194}
195
/// Rewrites an attribute-macro invocation into a call to the support macro.
///
/// The produced tokens have the shape
/// `<path>::__support::<macro_type>_parse!(#[<macro_type>(<attribute>)] <item>);`
/// where `<path>` is the token sequence following a top-level
/// `crate_path =` in `attribute` (up to the next comma), or `::<macro_crate>`
/// when no such entry exists. With an empty `attribute` the re-emitted
/// attribute is just `#[<macro_type>]`.
#[allow(unknown_lints, tail_expr_drop_order)]
fn generate(
    macro_type: &str,
    macro_crate: &str,
    attribute: TokenStream,
    item: TokenStream,
) -> TokenStream {
    /// Appends `::name` to `path`.
    fn push_segment(path: &mut TokenStream, name: &str) {
        path.extend([
            TokenTree::Punct(Punct::new(':', Spacing::Joint)),
            TokenTree::Punct(Punct::new(':', Spacing::Alone)),
            TokenTree::Ident(Ident::new(name, Span::call_site())),
        ]);
    }

    // Scan the attribute arguments for `crate_path = <tokens>`.
    let mut crate_path: Option<TokenStream> = None;
    let mut args = attribute.clone().into_iter();
    while let Some(arg) = args.next() {
        if !matches!(&arg, TokenTree::Ident(key) if key.to_string() == "crate_path") {
            continue;
        }
        // The token after the key is consumed whether or not it is `=`.
        if !matches!(args.next(), Some(TokenTree::Punct(eq)) if eq.as_char() == '=') {
            continue;
        }
        // Everything up to the next comma (or the end) is the path.
        let path: TokenStream = args
            .by_ref()
            .take_while(|tok| !matches!(tok, TokenTree::Punct(p) if p.as_char() == ','))
            .collect();
        crate_path = Some(path);
        break;
    }

    // Re-create the attribute: `#[macro_type]` or `#[macro_type(<attribute>)]`.
    let mut attr_body =
        TokenStream::from_iter([TokenTree::Ident(Ident::new(macro_type, Span::call_site()))]);
    if !attribute.is_empty() {
        attr_body.extend([TokenTree::Group(Group::new(
            Delimiter::Parenthesis,
            attribute,
        ))]);
    }

    let mut inner = TokenStream::from_iter([
        TokenTree::Punct(Punct::new('#', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Bracket, attr_body)),
    ]);
    inner.extend(item);

    // Start from the user-supplied path, or fall back to `::<macro_crate>`.
    let mut invoke = match crate_path {
        Some(path) => path,
        None => {
            let mut default_path = TokenStream::new();
            push_segment(&mut default_path, macro_crate);
            default_path
        }
    };

    push_segment(&mut invoke, "__support");
    push_segment(&mut invoke, &format!("{macro_type}_parse"));
    invoke.extend([
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, inner)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ]);

    invoke
}