eval_macro_internal/
lib.rs

1#![allow(clippy::panic)]
2#![allow(clippy::expect_used)]
3
4use proc_macro2::Delimiter;
5use proc_macro2::LineColumn;
6use proc_macro2::TokenStream;
7use proc_macro2::TokenTree;
8use quote::quote;
9use std::collections::hash_map::DefaultHasher;
10use std::fs::File;
11use std::fs;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::io::Write;
15use std::path::Path;
16use std::path::PathBuf;
17use std::process::Command;
18
19// =================
20// === Constants ===
21// =================
22
/// Set to `true` to print debugging information to stdout while the macro expands: the
/// per-token spans seen by `print_internal`, the rewritten macro input, and the final expansion.
const DEBUG: bool = false;
25
/// Name of this macro; also the name of the directory under `~/.cargo` where generated projects
/// are placed (see `get_output_dir`).
const NAME: &str = "eval_macro";
27
/// Rust keywords that receive extra spacing in `print_internal`. This is not needed for the macro
/// to work. It exists only to make `IntelliJ` / `RustRover` work correctly: their `TokenStream`
/// spans are incorrect, which can otherwise glue a keyword to its neighbouring tokens.
///
/// The list appears to follow the Rust Reference: strict keywords first, then reserved keywords
/// from `abstract` on. NOTE(review): `gen` (reserved in edition 2024, the default edition used by
/// `CargoConfig`) is not listed. Confirm whether it should be, keeping in mind that padding also
/// affects identifier concatenation like `{{prefix}}gen`.
const KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "macro",
    "override", "priv", "typeof", "unsized", "virtual", "yield", "try",
];
37
/// Helper code pasted verbatim into the generated program's `main.rs`, ahead of `fn main` (see
/// `eval`). It currently provides:
/// - `write_ln!(target, fmt, args...)`: appends `format!(fmt, args...)` followed by a newline to
///   the variable `target` (in practice the `output_buffer` string; every `output!` call is
///   rewritten into `write_ln!(output_buffer, ...)`).
/// - `sum_combinations(n)`: every ordered sequence of two or more positive integers that sums to
///   `n`.
///
/// After this lib stabilizes, these should be imported from this library instead of injected.
const PRELUDE: &str = "
    macro_rules! write_ln {
        ($target:ident, $($ts:tt)*) => {
            $target.push_str(&format!( $($ts)* ));
            $target.push_str(\"\n\");
        };
    }

    fn sum_combinations(n: usize) -> Vec<Vec<usize>> {
        let mut result = Vec::new();

        fn generate(n: usize, current: Vec<usize>, result: &mut Vec<Vec<usize>>) {
            if n == 0 {
                if current.len() > 1 {
                    result.push(current);
                }
                return;
            }

            for i in 1..=n {
                let mut next = current.clone();
                next.push(i);
                generate(n - i, next, result);
            }
        }

        generate(n, vec![], &mut result);
        result
    }
";
70
71// ============
72// === Path ===
73// ============
74
75/// Output directory for projects generated by this macro.
76fn get_output_dir() -> PathBuf {
77    let home_dir = std::env::var("HOME")
78        .expect("HOME environment variable not set — this is required to locate ~/.cargo.");
79
80    let eval_macro_dir = PathBuf::from(home_dir)
81        .join(".cargo")
82        .join(NAME);
83
84    if !eval_macro_dir.exists() {
85        fs::create_dir_all(&eval_macro_dir)
86            .expect("Failed to create ~/.cargo/eval_macro directory.");
87    }
88
89    eval_macro_dir
90}
91
92// ==========================
93// === Project Management ===
94// ==========================
95
/// Settings used to render the `Cargo.toml` of a generated project.
#[derive(Debug)]
struct CargoConfig {
    /// Edition written to `[package]`; `None` renders as `2024`.
    edition: Option<String>,
    /// Dependency lines, joined with newlines and placed verbatim under `[dependencies]`.
    dependencies: Vec<String>
}

impl CargoConfig {
    /// Config using the default edition and the given dependency lines.
    fn new(dependencies: Vec<String>) -> Self {
        let edition = None;
        Self { edition, dependencies }
    }

    /// Render the complete `Cargo.toml` text for the generated `eval_project` package.
    fn print(&self) -> String {
        let edition: &str = match &self.edition {
            Some(chosen) => chosen.as_str(),
            None => "2024",
        };
        let dependency_lines = self.dependencies.join("\n");
        format!(
            concat!(
                "\n",
                "            [package]\n",
                "            name = \"eval_project\"\n",
                "            version = \"1.0.0\"\n",
                "            edition = \"{edition}\"\n",
                "\n",
                "            [dependencies]\n",
                "            {dependencies}\n",
                "        "
            ),
            edition = edition,
            dependencies = dependency_lines,
        )
    }
}
124
/// Derive a directory name for the project generated from `input_str`.
///
/// The result is `project_` followed by the `DefaultHasher` hash of the input, rendered as exactly
/// 16 lowercase hex digits. Equal inputs give equal names within a build of this macro.
fn project_name_from_input(input_str: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(input_str, &mut state);
        state.finish()
    };
    format!("project_{digest:016x}")
}
130
131fn create_project_skeleton(project_dir: &Path, cargo_config: CargoConfig, main_content: &str) {
132    let src_dir = project_dir.join("src");
133    if !src_dir.exists() {
134        fs::create_dir_all(&src_dir).expect("Failed to create src directory.");
135    }
136
137    let cargo_toml = project_dir.join("Cargo.toml");
138    let cargo_toml_content = cargo_config.print();
139    fs::write(&cargo_toml, cargo_toml_content).expect("Failed to write Cargo.toml.");
140
141    let main_rs = src_dir.join("main.rs");
142    let mut file = File::create(&main_rs).expect("Failed to create main.rs");
143    file.write_all(main_content.as_bytes()).expect("Failed to write main.rs");
144}
145
/// Build and run the generated project in `project_dir` with `cargo run`, returning its stdout.
///
/// The `cargo` binary is taken from the `CARGO` environment variable when set. Cargo sets it for
/// the crates it compiles, so the nested build uses the same toolchain as the outer one. Otherwise
/// `cargo` from `PATH` is used (e.g. when the macro is expanded by an IDE).
///
/// The nested build's target directory is pinned to `<project_dir>/target`. That is already
/// Cargo's default for a standalone project. Making it explicit stops an inherited
/// `CARGO_TARGET_DIR` (or a configured `build.target-dir`) from pointing the nested build at a
/// target directory that the outer build may hold locked while it runs.
///
/// # Panics
/// Panics if `cargo` cannot be spawned or the project fails to compile or run. The captured
/// stderr is printed and also included in the panic message, so it shows up in compiler and IDE
/// diagnostics.
fn run_cargo_project(project_dir: &Path) -> String {
    let cargo = std::env::var_os("CARGO").unwrap_or_else(|| "cargo".into());
    let output = Command::new(cargo)
        .arg("run")
        .current_dir(project_dir)
        .env("CARGO_TARGET_DIR", project_dir.join("target"))
        .output()
        .expect("Failed to execute cargo run");

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        eprintln!("{stderr}");
        panic!("Cargo project failed to compile or run.\n{stderr}");
    }

    String::from_utf8_lossy(&output.stdout).to_string()
}
161
162// ===================================
163// === Top-level Attribute Parsing ===
164// ===================================
165
166fn extract_dependencies(tokens: TokenStream) -> (TokenStream, Vec<String>) {
167    let (tokens2, attributes) = extract_top_level_attributes(tokens);
168    let mut dependencies = Vec::new();
169    for attr in attributes {
170        let pfx = "dependency(";
171        if attr.starts_with(pfx) && attr.ends_with(")") {
172            dependencies.push(attr[pfx.len()..attr.len()-1].to_string());
173        } else {
174            panic!("Invalid attribute: {attr}");
175        }
176    }
177    (tokens2, dependencies)
178}
179
180/// Extract all top-level attributes (`#![...]`) from the input `TokenStream` and return the
181/// remaining `TokenStream`.
182fn extract_top_level_attributes(tokens: TokenStream) -> (TokenStream, Vec<String>) {
183    let mut output = TokenStream::new();
184    let mut attributes = Vec::new();
185    let mut iter = tokens.into_iter().peekable();
186    while let Some(_token) = iter.peek() {
187        if let Some(dep) = try_parse_inner_attr(&mut iter) {
188            attributes.push(dep);
189        } else if let Some(token) = iter.next() {
190            output.extend(Some(token));
191        }
192    }
193    (output, attributes)
194}
195
196/// Try to parse `#![...]` as an inner attribute and return the parsed content if successful.
197fn try_parse_inner_attr(iter: &mut std::iter::Peekable<impl Iterator<Item = TokenTree>>) -> Option<String> {
198    // Check for '#'.
199    let Some(TokenTree::Punct(pound)) = iter.peek() else { return None; };
200    if pound.as_char() != '#' { return None; }
201    iter.next();
202
203    // Check for '!'.
204    let Some(TokenTree::Punct(bang)) = iter.peek() else { return None; };
205    if bang.as_char() != '!' { return None; }
206    iter.next();
207
208    // Check for [ ... ] group.
209    let Some(TokenTree::Group(group)) = iter.peek() else { return None; };
210    if group.delimiter() != Delimiter::Bracket { return None; }
211    let content = group.stream().to_string();
212    iter.next();
213
214    Some(content)
215}
216
217// ====================
218// === Output Macro ===
219// ====================
220
221/// Find and expand the `output!` macro in the input `TokenStream`. After this lib stabilizes, this
222/// should be rewritten to standard macro and imported by the generated code.
223fn expand_output_macro(input: TokenStream) -> TokenStream {
224    let tokens: Vec<TokenTree> = input.into_iter().collect();
225    let mut output = TokenStream::new();
226    let mut i = 0;
227    while i < tokens.len() {
228        if let TokenTree::Ident(ref ident) = tokens[i] {
229            if *ident == "output" && i + 1 < tokens.len() {
230                if let TokenTree::Punct(ref excl) = tokens[i + 1] {
231                    if excl.as_char() == '!' && i + 2 < tokens.len() {
232                        if let TokenTree::Group(ref group) = tokens[i + 2] {
233                            let inner_rewritten = expand_output_macro(group.stream());
234                            let content_str = print(&inner_rewritten);
235                            let lit = syn::LitStr::new(&content_str, proc_macro2::Span::call_site());
236                            let new_tokens = quote! { write_ln!(output_buffer, #lit); };
237                            output.extend(new_tokens);
238                            i += 3;
239                            continue;
240                        }
241                    }
242                }
243            }
244        }
245        match &tokens[i] {
246            TokenTree::Group(group) => {
247                let new_stream = expand_output_macro(group.stream());
248                let new_group = TokenTree::Group(proc_macro2::Group::new(group.delimiter(), new_stream));
249                output.extend(std::iter::once(new_group));
250            }
251            _ => {
252                output.extend(std::iter::once(tokens[i].clone()));
253            }
254        }
255        i += 1;
256    }
257    output
258}
259
260// =============
261// === Print ===
262// =============
263
/// Result of `print_internal` for one token stream.
#[derive(Debug)]
struct PrintOutput {
    /// The rendered text. Tokens are separated by a space unless they were adjacent in the
    /// source, and a non-empty result ends with a trailing space.
    output: String,
    /// Start position of the first token. For a group token, this is derived from the group's
    /// content when it is non-empty. `None` for an empty stream.
    start_token: Option<LineColumn>,
    /// End position of the last token, adjusted the same way for groups. `None` for an empty
    /// stream.
    end_token: Option<LineColumn>,
}
270
// Marker indicating that `{{` and `}}` were already collapsed to `{` and `}`. A rendered group
// that begins with this marker does not start with `{`, so an enclosing brace group will not
// collapse it a second time. `print` removes all markers at the end. For example, the following
// transformations are performed:
// `{ {{a}} }` -> `{ {{{a}}} }` -> `{ %%%{a}%%% }` -> `{{ %%%{a}%%% }}` -> `{{ {a} }}`
const SPACER: &str = "%%%";
275
276/// Prints the token stream as a string ready to be used by the format macro. The following
277/// transformations are performed:
278/// - A trailing space is added after printing each token.
279/// - If `prev_token.line == next_token.line` and `prev_token.end_column >= next_token.start_column`,
280///   the prev token trailing space is removed. This basically preserves spaces from the input code,
281///   which is needed to distinguish between such inputs as `MyName{x}`, `MyName {x}`, etc.
282///   The `>=` is used because sometimes a few tokens can have the same start column. For example,
283///   for lifetimes (`'t`), the apostrophe and the lifetime ident have the same start column.
284/// - Every occurrence of `{{` and `}}` is replaced with `{` and `}` respectively. Also, every
285///   occurrence of `{` and `}` is replaced with `{{` and `}}` respectively. This prepares the
286///   string to be used in the format macro.
287///
288/// There is also a special transformation for `IntellIJ` / `RustRover`. Their spans are different from
289/// rustc ones, so sometimes tokens are glued together. This is why we discover Rust keywords and add
290/// additional spacing around. This is probably not covering all the bugs. If there will be a bug
291/// report, this is the place to look at.
292fn print(tokens: &TokenStream) -> String {
293    print_internal(tokens).output.replace("{%%%", "{ %%%").replace("%%%}", "%%% }").replace(SPACER, "")
294}
295
296fn print_internal(tokens: &TokenStream) -> PrintOutput {
297    let token_vec: Vec<TokenTree> = tokens.clone().into_iter().collect();
298    let mut output = String::new();
299    let mut first_token_start = None;
300    let mut prev_token_end: Option<LineColumn> = None;
301    for (i, token) in token_vec.iter().enumerate() {
302        let mut token_start = token.span().start();
303        let mut token_end = token.span().end();
304        let mut is_keyword = false;
305        let token_str = match token {
306            TokenTree::Group(g) => {
307                let content = print_internal(&g.stream());
308                let mut content_str = content.output;
309                content_str.pop();
310                let (open, close) = match g.delimiter() {
311                    Delimiter::Brace =>{
312                        if content_str.starts_with('{') && content_str.ends_with('}') {
313                            // We already replaced the internal `{` and `}` with `{{` and `}}`.
314                            content_str.pop();
315                            content_str.remove(0);
316                            (SPACER, SPACER)
317                        } else {
318                            ("{{", "}}")
319                        }
320                    },
321                    Delimiter::Parenthesis => ("(", ")"),
322                    Delimiter::Bracket => ("[", "]"),
323                    _ => ("", ""),
324                };
325
326                if let Some(content_first_token_start) = content.start_token {
327                    token_start.line = content_first_token_start.line;
328                    if content_first_token_start.column > 0 {
329                        token_start.column = content_first_token_start.column - 1;
330                    }
331                }
332                if let Some(content_end) = content.end_token {
333                    token_end.line = content_end.line;
334                    token_end.column = content_end.column + 1;
335                }
336                format!("{open}{content_str}{close}")
337            }
338            TokenTree::Ident(ident) => {
339                let str = ident.to_string();
340                is_keyword = KEYWORDS.contains(&str.as_str());
341                str
342            },
343            TokenTree::Literal(lit) => lit.to_string(),
344            TokenTree::Punct(punct) => punct.as_char().to_string(),
345        };
346        if DEBUG {
347            println!("{i}: [{token_start:?}-{token_end:?}] [{prev_token_end:?}]: {token}");
348        }
349        if let Some(prev_token_end) = prev_token_end {
350            if prev_token_end.line == token_start.line && prev_token_end.column >= token_start.column {
351                output.pop();
352            }
353        }
354
355        // Pushing a space before and after keywords is for IntelliJ only. Their token spans are invalid.
356        if is_keyword { output.push(' '); }
357        output.push_str(&token_str);
358        output.push(' ');
359        if is_keyword { output.push(' '); }
360
361        first_token_start.get_or_insert(token_start);
362        prev_token_end = Some(token_end);
363    }
364    PrintOutput {
365        output,
366        start_token: first_token_start,
367        end_token: prev_token_end,
368    }
369}
370
371// ==================
372// === Eval Macro ===
373// ==================
374
375#[proc_macro]
376pub fn eval(input_raw: proc_macro::TokenStream) -> proc_macro::TokenStream {
377    let (input, dependencies) = extract_dependencies(input_raw.into());
378    let cargo_config = CargoConfig::new(dependencies);
379
380    let input_str = expand_output_macro(input).to_string();
381    if DEBUG { println!("REWRITTEN INPUT: {input_str}"); }
382
383    let out_dir = get_output_dir();
384    let project_name = project_name_from_input(&input_str);
385    let project_dir = out_dir.join(&project_name);
386
387    if !project_dir.exists() {
388        fs::create_dir_all(&project_dir)
389            .expect("Failed to create project directory.");
390    }
391
392    let main_content = format!(
393        "{PRELUDE}
394        fn main() {{
395            let mut output_buffer = String::new();
396            {input_str}
397            println!(\"{{output_buffer}}\");
398        }}",
399    );
400
401    create_project_skeleton(&project_dir, cargo_config, &main_content);
402    let output = run_cargo_project(&project_dir);
403    fs::remove_dir_all(&project_dir).ok();
404    let out: TokenStream = output.parse().expect("Failed to parse generated code.");
405    if DEBUG {
406        println!("OUT: {out}");
407    }
408    out.into()
409}