solar_macros/symbols/
mod.rs

//! Proc macro that builds the `Symbol` table.
//!
//! # Debugging
//!
//! Since this proc macro does some non-trivial work, being able to debug it is important.
//! It can be invoked as an ordinary unit test, like so:
//!
//! ```bash
//! cd crates/macros
//! cargo test symbols::test_symbols -- --nocapture
//! ```
//!
//! This unit test finds the `symbols!` invocation in `crates/interface/src/symbol.rs` and runs it.
//! It verifies that the output token stream can be parsed as valid module items and that no errors
//! were produced.
//!
//! You can also view the generated code by using `cargo expand`:
//!
//! ```bash
//! # This is only necessary once.
//! cargo install cargo-expand
//!
//! cd crates/interface
//! # The output is large.
//! cargo +nightly expand > /tmp/solar_interface.rs
//! ```

28use proc_macro2::{Span, TokenStream};
29use quote::quote;
30use std::collections::HashMap;
31use syn::{
32    Expr, Ident, Lit, LitStr, Macro, Token, braced,
33    parse::{Parse, ParseStream, Result},
34    punctuated::Punctuated,
35};
36
// Unit tests for this macro; only compiled for `cargo test`.
#[cfg(test)]
mod tests;

40mod kw {
41    syn::custom_keyword!(Keywords);
42    syn::custom_keyword!(Symbols);
43}
44
45struct Keyword {
46    name: Ident,
47    value: LitStr,
48}
49
50impl Parse for Keyword {
51    fn parse(input: ParseStream<'_>) -> Result<Self> {
52        let name = input.parse()?;
53        input.parse::<Token![:]>()?;
54        let value = input.parse()?;
55
56        Ok(Self { name, value })
57    }
58}
59
60struct Symbol {
61    name: Ident,
62    value: Value,
63}
64
65enum Value {
66    SameAsName,
67    String(LitStr),
68    Env(LitStr, Macro),
69    Unsupported(Expr),
70}
71
72impl Parse for Symbol {
73    fn parse(input: ParseStream<'_>) -> Result<Self> {
74        let name = input.parse()?;
75        let colon_token: Option<Token![:]> = input.parse()?;
76        let value = if colon_token.is_some() { input.parse()? } else { Value::SameAsName };
77
78        Ok(Self { name, value })
79    }
80}
81
82impl Parse for Value {
83    fn parse(input: ParseStream<'_>) -> Result<Self> {
84        let expr: Expr = input.parse()?;
85        match &expr {
86            Expr::Lit(expr) => {
87                if let Lit::Str(lit) = &expr.lit {
88                    return Ok(Self::String(lit.clone()));
89                }
90            }
91            Expr::Macro(expr) => {
92                if expr.mac.path.is_ident("env")
93                    && let Ok(lit) = expr.mac.parse_body()
94                {
95                    return Ok(Self::Env(lit, expr.mac.clone()));
96                }
97            }
98            _ => {}
99        }
100        Ok(Self::Unsupported(expr))
101    }
102}
103
104struct Input {
105    keywords: Punctuated<Keyword, Token![,]>,
106    symbols: Punctuated<Symbol, Token![,]>,
107}
108
109impl Parse for Input {
110    fn parse(input: ParseStream<'_>) -> Result<Self> {
111        input.parse::<kw::Keywords>()?;
112        let content;
113        braced!(content in input);
114        let keywords = Punctuated::parse_terminated(&content)?;
115
116        input.parse::<kw::Symbols>()?;
117        let content;
118        braced!(content in input);
119        let symbols = Punctuated::parse_terminated(&content)?;
120
121        Ok(Self { keywords, symbols })
122    }
123}
124
125#[derive(Default)]
126struct Errors {
127    list: Vec<syn::Error>,
128}
129
130impl Errors {
131    fn error(&mut self, span: Span, msg: String) {
132        self.list.push(syn::Error::new(span, msg));
133    }
134}
135
136pub fn symbols(input: TokenStream) -> TokenStream {
137    let (mut output, errors) = symbols_with_errors(input);
138
139    // If we generated any errors, then report them as compiler_error!() macro calls.
140    // This lets the errors point back to the most relevant span. It also allows us
141    // to report as many errors as we can during a single run.
142    output.extend(errors.into_iter().map(|e| e.to_compile_error()));
143
144    output
145}
146
147struct Preinterned {
148    idx: u32,
149    span_of_name: Span,
150}
151
152struct Entries {
153    map: HashMap<String, Preinterned>,
154}
155
156impl Entries {
157    fn with_capacity(capacity: usize) -> Self {
158        Self { map: HashMap::with_capacity(capacity) }
159    }
160
161    fn insert(&mut self, span: Span, str: &str, errors: &mut Errors) -> u32 {
162        if let Some(prev) = self.map.get(str) {
163            errors.error(span, format!("Symbol `{str}` is duplicated"));
164            errors.error(prev.span_of_name, "location of previous definition".to_string());
165            prev.idx
166        } else {
167            let idx = self.len();
168            self.map.insert(str.to_string(), Preinterned { idx, span_of_name: span });
169            idx
170        }
171    }
172
173    fn len(&self) -> u32 {
174        u32::try_from(self.map.len()).expect("way too many symbols")
175    }
176}
177
178fn symbols_with_errors(input: TokenStream) -> (TokenStream, Vec<syn::Error>) {
179    let mut errors = Errors::default();
180
181    let input: Input = match syn::parse2(input) {
182        Ok(input) => input,
183        Err(e) => {
184            // This allows us to display errors at the proper span, while minimizing
185            // unrelated errors caused by bailing out (and not generating code).
186            errors.list.push(e);
187            Input { keywords: Default::default(), symbols: Default::default() }
188        }
189    };
190
191    let mut keyword_stream = quote! {};
192    let mut symbols_stream = quote! {};
193    let mut prefill_stream = quote! {};
194    let mut entries = Entries::with_capacity(input.keywords.len() + input.symbols.len() + 10);
195    let mut prev_key: Option<(Span, String)> = None;
196
197    let mut check_order = |span: Span, str: &str, errors: &mut Errors| {
198        if let Some((prev_span, ref prev_str)) = prev_key
199            && str < prev_str
200        {
201            errors.error(span, format!("Symbol `{str}` must precede `{prev_str}`"));
202            errors.error(prev_span, format!("location of previous symbol `{prev_str}`"));
203        }
204        prev_key = Some((span, str.to_string()));
205    };
206
207    // Generate the listed keywords.
208    for keyword in input.keywords.iter() {
209        let name = &keyword.name;
210        let value = &keyword.value;
211        let value_string = value.value();
212        let idx = entries.insert(keyword.name.span(), &value_string, &mut errors);
213        prefill_stream.extend(quote! {
214            #value,
215        });
216        keyword_stream.extend(quote! {
217            pub const #name: Symbol = Symbol::new(#idx);
218        });
219    }
220
221    // Generate the listed symbols.
222    for symbol in input.symbols.iter() {
223        let name = &symbol.name;
224        check_order(symbol.name.span(), &name.to_string(), &mut errors);
225
226        let value = match &symbol.value {
227            Value::SameAsName => name.to_string(),
228            Value::String(lit) => lit.value(),
229            Value::Env(..) => continue, // in another loop below
230            Value::Unsupported(expr) => {
231                errors.list.push(syn::Error::new_spanned(
232                    expr,
233                    concat!(
234                        "unsupported expression for symbol value; implement support for this in ",
235                        file!(),
236                    ),
237                ));
238                continue;
239            }
240        };
241        let idx = entries.insert(symbol.name.span(), &value, &mut errors);
242
243        prefill_stream.extend(quote! {
244            #value,
245        });
246        symbols_stream.extend(quote! {
247            pub const #name: Symbol = Symbol::new(#idx);
248        });
249    }
250
251    // Generate symbols for the strings "0", "1", ..., "9".
252    for n in 0..10 {
253        let n = n.to_string();
254        entries.insert(Span::call_site(), &n, &mut errors);
255        prefill_stream.extend(quote! {
256            #n,
257        });
258    }
259
260    // Symbols whose value comes from an environment variable. It's allowed for
261    // these to have the same value as another symbol.
262    for symbol in &input.symbols {
263        let (env_var, expr) = match &symbol.value {
264            Value::Env(lit, expr) => (lit, expr),
265            Value::SameAsName | Value::String(_) | Value::Unsupported(_) => continue,
266        };
267
268        if !proc_macro::is_available() {
269            errors.error(
270                Span::call_site(),
271                "proc_macro::tracked_env is not available in unit test".to_owned(),
272            );
273            break;
274        }
275
276        let value = match std::env::var(env_var.value()) {
277            Ok(value) => value,
278            Err(err) => {
279                errors.list.push(syn::Error::new_spanned(expr, err));
280                continue;
281            }
282        };
283
284        let idx = if let Some(prev) = entries.map.get(&value) {
285            prev.idx
286        } else {
287            prefill_stream.extend(quote! {
288                #value,
289            });
290            entries.insert(symbol.name.span(), &value, &mut errors)
291        };
292
293        let name = &symbol.name;
294        symbols_stream.extend(quote! {
295            pub const #name: Symbol = Symbol::new(#idx);
296        });
297    }
298
299    let symbol_digits_base = entries.map["0"].idx;
300    let preinterned_symbols_count = entries.len();
301    let output = quote! {
302        const PREINTERNED: &[&str] = &[#prefill_stream];
303        const SYMBOL_DIGITS_BASE: u32 = #symbol_digits_base;
304        const PREINTERNED_SYMBOLS_COUNT: u32 = #preinterned_symbols_count;
305
306        #[allow(non_upper_case_globals)]
307        #[doc(hidden)]
308        mod kw_generated {
309            use super::Symbol;
310            #keyword_stream
311        }
312
313        #[allow(non_upper_case_globals)]
314        #[doc(hidden)]
315        mod sym_generated {
316            use super::Symbol;
317            #symbols_stream
318        }
319    };
320
321    (output, errors.list)
322}