eval_macro_internal/
lib.rs

1#![allow(clippy::panic)]
2#![allow(clippy::expect_used)]
3
4use proc_macro2::Delimiter;
5use proc_macro2::LineColumn;
6use proc_macro2::TokenStream;
7use proc_macro2::TokenTree;
8use quote::quote;
9use std::collections::hash_map::DefaultHasher;
10use std::fs::File;
11use std::fs;
12use std::hash::Hash;
13use std::hash::Hasher;
14use std::io::Write;
15use std::path::Path;
16use std::path::PathBuf;
17use std::process::Command;
18
19// =================
20// === Constants ===
21// =================
22
/// Set to `true` to enable debug prints: the rewritten input, per-token spans while printing
/// `output!` bodies, and the final generated token stream.
const DEBUG: bool = false;

/// Name of the directory (under `~/.cargo`) in which the temporary evaluation projects are created.
const NAME: &str = "eval_macro";

// Line prefixes recognized by `eval` when scanning the stdout of the evaluated program (after the
// line has been trimmed). They match the `prefix_lines_with_*` helpers injected via `PRELUDE`,
// which print lines as `"<PREFIX>: <line>"`.
/// Lines with this prefix are collected (prefix removed) as the macro's generated code.
const OUTPUT_PREFIX: &str = "OUTPUT:";
/// Lines with this prefix are re-printed as `[WARNING] ...`.
const WARNING_PREFIX: &str = "WARNING:";
/// Lines with this prefix are re-printed as `[ERROR] ...`.
const ERROR_PREFIX: &str = "ERROR:";

/// Rust keywords for special handling. This is not needed for this macro to work, it is only used
/// to make `IntelliJ` / `RustRover` work correctly, as their `TokenStream` spans are incorrect.
/// `print_internal` pads identifiers found in this list with an extra space on both sides.
const KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "macro",
    "override", "priv", "typeof", "unsized", "virtual", "yield", "try",
];
41
/// Helper code pasted verbatim into the generated `main.rs`, ahead of the evaluated user code.
/// After this lib stabilizes, these helpers should be imported from this library, not injected.
///
/// Contents:
/// - `write_ln!(buf, ...)` appends a formatted line plus `\n` to a `String`. `output!` blocks are
///   rewritten into calls to it on `output_buffer`.
/// - `println_output!` / `println_warning!` / `println_error!` print a formatted message with
///   every line prefixed by `OUTPUT: ` / `WARNING: ` / `ERROR: `. `eval` routes these lines by
///   prefix.
/// - `sum_combinations(n)` returns every ordered way to write `n` as a sum of at least two
///   positive integers.
/// - `push_as_str(buf, value)` appends the `Debug` form of `value`. It skips `()` and strips one
///   pair of surrounding parentheses, so a tuple result contributes its elements.
const PRELUDE: &str = "
    macro_rules! write_ln {
        ($target:ident, $($ts:tt)*) => {
            $target.push_str(&format!( $($ts)* ));
            $target.push_str(\"\n\");
        };
    }

    macro_rules! println_output {
        ($($ts:tt)*) => {
            println!(\"{}\", prefix_lines_with_output(&format!( $($ts)* )));
        };
    }

    macro_rules! println_warning {
        ($($ts:tt)*) => {
            println!(\"{}\", prefix_lines_with_warning(&format!( $($ts)* )));
        };
    }

    macro_rules! println_error {
        ($($ts:tt)*) => {
            println!(\"{}\", prefix_lines_with_error(&format!( $($ts)* )));
        };
    }

    fn prefix_lines_with(prefix: &str, input: &str) -> String {
        input
            .lines()
            .map(|line| format!(\"{prefix}: {line}\"))
            .collect::<Vec<_>>()
            .join(\"\\n\")
    }

    fn prefix_lines_with_output(input: &str) -> String {
        prefix_lines_with(\"OUTPUT\", input)
    }

    fn prefix_lines_with_warning(input: &str) -> String {
        prefix_lines_with(\"WARNING\", input)
    }

    fn prefix_lines_with_error(input: &str) -> String {
        prefix_lines_with(\"ERROR\", input)
    }

    fn sum_combinations(n: usize) -> Vec<Vec<usize>> {
        let mut result = Vec::new();

        fn generate(n: usize, current: Vec<usize>, result: &mut Vec<Vec<usize>>) {
            if n == 0 {
                if current.len() > 1 {
                    result.push(current);
                }
                return;
            }

            for i in 1..=n {
                let mut next = current.clone();
                next.push(i);
                generate(n - i, next, result);
            }
        }

        generate(n, vec![], &mut result);
        result
    }

    fn push_as_str<T: std::fmt::Debug>(str: &mut String, value: &T) {
        let repr = format!(\"{value:?}\");
        if repr != \"()\" {
            if repr.starts_with(\"(\") && repr.ends_with(\")\") {
                str.push_str(&repr[1..repr.len() - 1]);
            } else {
                str.push_str(&repr);
            }
        }
    }
";
123
124// ============
125// === Path ===
126// ============
127
128/// Output directory for projects generated by this macro.
129fn get_output_dir() -> PathBuf {
130    let home_dir = std::env::var("HOME")
131        .expect("HOME environment variable not set — this is required to locate ~/.cargo.");
132
133    let eval_macro_dir = PathBuf::from(home_dir)
134        .join(".cargo")
135        .join(NAME);
136
137    if !eval_macro_dir.exists() {
138        fs::create_dir_all(&eval_macro_dir)
139            .expect("Failed to create ~/.cargo/eval_macro directory.");
140    }
141
142    eval_macro_dir
143}
144
145// ==========================
146// === Project Management ===
147// ==========================
148
/// Everything collected from the macro's inner attributes that shapes the generated project.
#[derive(Debug, Default)]
struct ProjectConfig {
    /// Settings rendered into the generated `Cargo.toml`.
    cargo: CargoConfig,
    /// Crate-level attributes rendered at the top of the generated `main.rs`.
    lib: LibConfig,
}

/// `Cargo.toml` settings for the generated project.
#[derive(Debug, Default)]
struct CargoConfig {
    /// `package.edition`; `None` renders as `2024`.
    edition: Option<String>,
    /// `package.resolver`; `None` renders as `3`.
    resolver: Option<String>,
    /// Raw `[dependencies]` entries (e.g. `serde = "1"`), written verbatim, one per line.
    dependencies: Vec<String>,
}

/// Crate-level inner attributes for the generated `main.rs`. Each entry is the text placed
/// inside the parentheses of the corresponding attribute.
#[derive(Debug, Default)]
struct LibConfig {
    /// Emitted as `#![feature(..)]`.
    features: Vec<String>,
    /// Emitted as `#![allow(..)]`.
    allow: Vec<String>,
    /// Emitted as `#![expect(..)]`.
    expect: Vec<String>,
    /// Emitted as `#![warn(..)]`.
    warn: Vec<String>,
    /// Emitted as `#![deny(..)]`.
    deny: Vec<String>,
    /// Emitted as `#![forbid(..)]`.
    forbid: Vec<String>,
}
171
172impl CargoConfig {
173    fn print(&self) -> String {
174        let edition = self.edition.as_ref().map_or("2024", |t| t.as_str());
175        let resolver = self.resolver.as_ref().map_or("3", |t| t.as_str());
176        let dependencies = self.dependencies.join("\n");
177        format!("
178            [package]
179            name     = \"eval_project\"
180            version  = \"1.0.0\"
181            edition  = \"{edition}\"
182            resolver = \"{resolver}\"
183
184            [dependencies]
185            {dependencies}
186        ")
187    }
188}
189
190impl LibConfig {
191    fn print(&self) -> String {
192        let mut out = vec![];
193        out.extend(self.features.iter().map(|t| format!("#![feature({})]", t)));
194        out.extend(self.allow.iter().map(|t| format!("#![allow({})]", t)));
195        out.extend(self.expect.iter().map(|t| format!("#![expect({})]", t)));
196        out.extend(self.warn.iter().map(|t| format!("#![warn({})]", t)));
197        out.extend(self.deny.iter().map(|t| format!("#![deny({})]", t)));
198        out.extend(self.forbid.iter().map(|t| format!("#![forbid({})]", t)));
199        out.join("\n")
200    }
201}
202
/// Derives the generated project's directory name from the rewritten macro input.
///
/// The name is `project_` followed by the 64-bit `DefaultHasher` digest of `input_str`, as 16
/// zero-padded lowercase hex digits. Identical inputs map to the same name, because
/// `DefaultHasher::new` always starts from the same state.
fn project_name_from_input(input_str: &str) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(input_str, &mut state);
        state.finish()
    };
    format!("project_{digest:016x}")
}
208
209fn create_project_skeleton(project_dir: &Path, cfg: ProjectConfig, main_content: &str) {
210    let src_dir = project_dir.join("src");
211    if !src_dir.exists() {
212        fs::create_dir_all(&src_dir).expect("Failed to create src directory.");
213    }
214
215    let cargo_toml = project_dir.join("Cargo.toml");
216    let cargo_toml_content = cfg.cargo.print();
217    fs::write(&cargo_toml, cargo_toml_content).expect("Failed to write Cargo.toml.");
218
219    let main_rs = src_dir.join("main.rs");
220    let mut file = File::create(&main_rs).expect("Failed to create main.rs");
221    file.write_all(main_content.as_bytes()).expect("Failed to write main.rs");
222}
223
/// Builds and runs the generated project with `cargo run` and returns its stdout.
///
/// # Panics
/// Panics if `cargo` cannot be spawned, or if the project fails to compile or exits with a
/// non-zero status. In the latter case cargo's stderr is echoed to stderr and also included in
/// the panic message, so it shows up in the macro's compile error (including in IDEs).
fn run_cargo_project(project_dir: &Path) -> String {
    let output = Command::new("cargo")
        .arg("run")
        // Keep the build inside the project. A `CARGO_TARGET_DIR` inherited from the environment
        // (or a global `build.target-dir` config, which this variable overrides) could point this
        // nested cargo at the target directory whose lock is held by the outer build that is
        // expanding this macro. With no override, `<project>/target` is cargo's default anyway.
        .env("CARGO_TARGET_DIR", project_dir.join("target"))
        .current_dir(project_dir)
        .output()
        .unwrap_or_else(|err| panic!("Failed to execute cargo run: {err}"));

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        eprintln!("{stderr}");
        panic!("Cargo project failed to compile or run.\n{stderr}");
    }

    String::from_utf8_lossy(&output.stdout).into_owned()
}
239
240// ===================================
241// === Top-level Attribute Parsing ===
242// ===================================
243
244fn extract_dependencies(cfg: &mut ProjectConfig, tokens: TokenStream) -> TokenStream {
245    let (tokens2, attributes) = extract_top_level_attributes(tokens);
246    for attr in attributes {
247        if let Some(pos) = attr.find('(') {
248            let name = &attr[..pos];
249            let value = (&attr[pos + 1.. attr.len() - 1]).to_string(); // Skipping ")".
250            match name {
251                "dependency" => cfg.cargo.dependencies.push(value),
252                "edition" => {
253                    if cfg.cargo.edition.is_some() {
254                        panic!("Edition already set.");
255                    }
256                    cfg.cargo.edition = Some(value);
257                },
258                "resolver" => {
259                    if cfg.cargo.resolver.is_some() {
260                        panic!("Resolver already set.");
261                    }
262                    cfg.cargo.resolver = Some(value);
263                },
264
265                "feature" => cfg.lib.features.push(value),
266                "allow" => cfg.lib.allow.push(value),
267                "expect" => cfg.lib.expect.push(value),
268                "warn" => cfg.lib.warn.push(value),
269                "deny" => cfg.lib.deny.push(value),
270                "forbid" => cfg.lib.forbid.push(value),
271                _ => panic!("Invalid attribute: {attr}"),
272            }
273        } else {
274            panic!("Invalid attribute: {attr}");
275        }
276    }
277    tokens2
278}
279
280/// Extract all top-level attributes (`#![...]`) from the input `TokenStream` and return the
281/// remaining `TokenStream`.
282fn extract_top_level_attributes(tokens: TokenStream) -> (TokenStream, Vec<String>) {
283    let mut output = TokenStream::new();
284    let mut attributes = Vec::new();
285    let mut iter = tokens.into_iter().peekable();
286    while let Some(_token) = iter.peek() {
287        if let Some(dep) = try_parse_inner_attr(&mut iter) {
288            attributes.push(dep);
289        } else if let Some(token) = iter.next() {
290            output.extend(Some(token));
291        }
292    }
293    (output, attributes)
294}
295
296/// Try to parse `#![...]` as an inner attribute and return the parsed content if successful.
297fn try_parse_inner_attr(iter: &mut std::iter::Peekable<impl Iterator<Item = TokenTree>>) -> Option<String> {
298    // Check for '#'.
299    let Some(TokenTree::Punct(pound)) = iter.peek() else { return None; };
300    if pound.as_char() != '#' { return None; }
301    iter.next();
302
303    // Check for '!'.
304    let Some(TokenTree::Punct(bang)) = iter.peek() else { return None; };
305    if bang.as_char() != '!' { return None; }
306    iter.next();
307
308    // Check for [ ... ] group.
309    let Some(TokenTree::Group(group)) = iter.peek() else { return None; };
310    if group.delimiter() != Delimiter::Bracket { return None; }
311    let content = group.stream().to_string();
312    iter.next();
313
314    Some(content)
315}
316
317// ====================
318// === Output Macro ===
319// ====================
320
321/// Find and expand the `output!` macro in the input `TokenStream`. After this lib stabilizes, this
322/// should be rewritten to standard macro and imported by the generated code.
323fn expand_output_macro(input: TokenStream) -> TokenStream {
324    let tokens: Vec<TokenTree> = input.into_iter().collect();
325    let mut output = TokenStream::new();
326    let mut i = 0;
327    while i < tokens.len() {
328        if let TokenTree::Ident(ref ident) = tokens[i] {
329            if *ident == "output" && i + 1 < tokens.len() {
330                if let TokenTree::Punct(ref excl) = tokens[i + 1] {
331                    if excl.as_char() == '!' && i + 2 < tokens.len() {
332                        if let TokenTree::Group(ref group) = tokens[i + 2] {
333                            let inner_rewritten = expand_output_macro(group.stream());
334                            let content_str = print(&inner_rewritten);
335                            let lit = syn::LitStr::new(&content_str, proc_macro2::Span::call_site());
336                            let new_tokens = quote! { write_ln!(output_buffer, #lit); };
337                            output.extend(new_tokens);
338                            i += 3;
339                            continue;
340                        }
341                    }
342                }
343            }
344        }
345        match &tokens[i] {
346            TokenTree::Group(group) => {
347                let new_stream = expand_output_macro(group.stream());
348                let new_group = TokenTree::Group(proc_macro2::Group::new(group.delimiter(), new_stream));
349                output.extend(std::iter::once(new_group));
350            }
351            _ => {
352                output.extend(std::iter::once(tokens[i].clone()));
353            }
354        }
355        i += 1;
356    }
357    output
358}
359
360// =============
361// === Print ===
362// =============
363
/// Result of `print_internal`: the rendered text plus the source positions that an enclosing
/// group uses to derive its own position.
#[derive(Debug)]
struct PrintOutput {
    /// Rendered tokens. When non-empty, it ends with the separator space pushed after the last
    /// token.
    output: String,
    /// Start position of the first rendered token (for a group, the position derived from its
    /// content), or `None` for an empty stream.
    start_token: Option<LineColumn>,
    /// End position of the last rendered token (for a group, the position derived from its
    /// content), or `None` for an empty stream.
    end_token: Option<LineColumn>,
}
370
// Marker showing that a user-written `{{x}}` was already collapsed to an interpolation `{x}`.
//
// Rendering works in three steps. The inner group `{x}` renders as `{{x}}` (braces escaped for
// `format!`). The outer brace group then sees content that is itself wrapped in braces, so it
// strips one layer to get `{x}` and wraps that in this marker instead of escaped braces, giving
// `%%%{x}%%%`. Finally, `print` removes every occurrence of the marker, leaving the `{x}`
// placeholder. A literal `%%%` in the input is therefore removed as well.
//
// Example of the full transformation:
// `{ {{a}} }` -> `{ {{{a}}} }` -> `{ %%%{a}%%% }` -> `{{ %%%{a}%%% }}` -> `{{ {a} }}`
const SPACER: &str = "%%%";
375
376/// Prints the token stream as a string ready to be used by the format macro. The following
377/// transformations are performed:
378/// - A trailing space is added after printing each token.
379/// - If `prev_token.line == next_token.line` and `prev_token.end_column >= next_token.start_column`,
380///   the prev token trailing space is removed. This basically preserves spaces from the input code,
381///   which is needed to distinguish between such inputs as `MyName{x}`, `MyName {x}`, etc.
382///   The `>=` is used because sometimes a few tokens can have the same start column. For example,
383///   for lifetimes (`'t`), the apostrophe and the lifetime ident have the same start column.
384/// - Every occurrence of `{{` and `}}` is replaced with `{` and `}` respectively. Also, every
385///   occurrence of `{` and `}` is replaced with `{{` and `}}` respectively. This prepares the
386///   string to be used in the format macro.
387///
388/// There is also a special transformation for `IntellIJ` / `RustRover`. Their spans are different from
389/// rustc ones, so sometimes tokens are glued together. This is why we discover Rust keywords and add
390/// additional spacing around. This is probably not covering all the bugs. If there will be a bug
391/// report, this is the place to look at.
392fn print(tokens: &TokenStream) -> String {
393    print_internal(tokens).output.replace("{%%%", "{ %%%").replace("%%%}", "%%% }").replace(SPACER, "")
394}
395
/// Recursive worker behind `print`. Renders `tokens` as `format!`-template text and reports the
/// positions of the first and last rendered tokens.
///
/// Spacing: each token is followed by one space. That space is removed again when the next token
/// starts on the same line at or before the previous token's end column. Keywords (`KEYWORDS`)
/// get an extra space on both sides.
///
/// Groups: the content is rendered recursively and its trailing space is dropped. The group's
/// start/end positions are then re-derived from its content: one column before the first inner
/// token and one column after the last. Brace groups are emitted as `{{ .. }}` (escaped for
/// `format!`). The exception is when the rendered content is itself wrapped in braces: then one
/// brace layer is stripped and `SPACER` markers are used instead, turning a user-written `{{x}}`
/// into a `{x}` interpolation. Groups without delimiters are emitted bare.
///
/// NOTE(review): the "content wrapped in braces" test only checks the first and last characters.
/// A brace group holding two sibling brace groups (`{ {a} {b} }`) also matches, and one brace is
/// stripped from each end. Confirm whether such input needs to be supported.
fn print_internal(tokens: &TokenStream) -> PrintOutput {
    let token_vec: Vec<TokenTree> = tokens.clone().into_iter().collect();
    let mut output = String::new();
    let mut first_token_start = None;
    let mut prev_token_end: Option<LineColumn> = None;
    for (i, token) in token_vec.iter().enumerate() {
        let mut token_start = token.span().start();
        let mut token_end = token.span().end();
        let mut is_keyword = false;
        let token_str = match token {
            TokenTree::Group(g) => {
                let content = print_internal(&g.stream());
                let mut content_str = content.output;
                // Drop the separator space that follows the last inner token.
                content_str.pop();
                let (open, close) = match g.delimiter() {
                    Delimiter::Brace =>{
                        if content_str.starts_with('{') && content_str.ends_with('}') {
                            // We already replaced the internal `{` and `}` with `{{` and `}}`.
                            content_str.pop();
                            content_str.remove(0);
                            (SPACER, SPACER)
                        } else {
                            ("{{", "}}")
                        }
                    },
                    Delimiter::Parenthesis => ("(", ")"),
                    Delimiter::Bracket => ("[", "]"),
                    _ => ("", ""),
                };

                // Position the group from its content, assuming the delimiters sit directly next
                // to the first/last inner token. An empty group keeps its own span.
                if let Some(content_first_token_start) = content.start_token {
                    token_start.line = content_first_token_start.line;
                    if content_first_token_start.column > 0 {
                        token_start.column = content_first_token_start.column - 1;
                    }
                }
                if let Some(content_end) = content.end_token {
                    token_end.line = content_end.line;
                    token_end.column = content_end.column + 1;
                }
                format!("{open}{content_str}{close}")
            }
            TokenTree::Ident(ident) => {
                let str = ident.to_string();
                is_keyword = KEYWORDS.contains(&str.as_str());
                str
            },
            TokenTree::Literal(lit) => lit.to_string(),
            TokenTree::Punct(punct) => punct.as_char().to_string(),
        };
        if DEBUG {
            println!("{i}: [{token_start:?}-{token_end:?}] [{prev_token_end:?}]: {token}");
        }
        // The tokens touch (or, with imprecise spans, overlap) on the same line, so they were
        // written without whitespace in the source: remove the separator pushed after the
        // previous token.
        if let Some(prev_token_end) = prev_token_end {
            if prev_token_end.line == token_start.line && prev_token_end.column >= token_start.column {
                output.pop();
            }
        }

        // Pushing a space before and after keywords is for IntelliJ only. Their token spans are invalid.
        if is_keyword { output.push(' '); }
        output.push_str(&token_str);
        output.push(' ');
        if is_keyword { output.push(' '); }

        first_token_start.get_or_insert(token_start);
        prev_token_end = Some(token_end);
    }
    PrintOutput {
        output,
        start_token: first_token_start,
        end_token: prev_token_end,
    }
}
470
471// ==================
472// === Eval Macro ===
473// ==================
474
475#[proc_macro]
476pub fn eval(input_raw: proc_macro::TokenStream) -> proc_macro::TokenStream {
477    let mut cfg = ProjectConfig::default();
478    let input = extract_dependencies(&mut cfg, input_raw.into());
479
480    let input_str = expand_output_macro(input).to_string();
481    let input_str_esc: String = input_str.chars().flat_map(|c| c.escape_default()).collect();
482    if DEBUG { println!("REWRITTEN INPUT: {input_str}"); }
483
484    let out_dir = get_output_dir();
485    let project_name = project_name_from_input(&input_str);
486    let project_dir = out_dir.join(&project_name);
487
488    if !project_dir.exists() {
489        fs::create_dir_all(&project_dir)
490            .expect("Failed to create project directory.");
491    }
492
493    let attrs = cfg.lib.print();
494    let main_content = format!(
495        "{attrs}
496        {PRELUDE}
497
498        const SOURCE_CODE: &str = \"{input_str_esc}\";
499
500        fn main() {{
501            let mut output_buffer = String::new();
502            let result = {{
503                {input_str}
504            }};
505            push_as_str(&mut output_buffer, &result);
506            println!(\"{{}}\", prefix_lines_with_output(&output_buffer));
507        }}",
508    );
509
510    create_project_skeleton(&project_dir, cfg, &main_content);
511    let output = run_cargo_project(&project_dir);
512    fs::remove_dir_all(&project_dir).ok();
513    let mut code = String::new();
514    for line in output.split('\n') {
515        let line_trimmed = line.trim();
516        if line_trimmed.starts_with(OUTPUT_PREFIX) {
517            code.push_str(&line_trimmed[OUTPUT_PREFIX.len()..]);
518            code.push('\n');
519        } else if line_trimmed.starts_with(WARNING_PREFIX) {
520            println!("[WARNING] {}", &line_trimmed[WARNING_PREFIX.len()..]);
521        } else if line_trimmed.starts_with(ERROR_PREFIX) {
522            println!("[ERROR] {}", &line_trimmed[ERROR_PREFIX.len()..]);
523        } else if line_trimmed.len() > 0 {
524            println!("{line}");
525        }
526    }
527
528    let out: TokenStream = code.parse().expect("Failed to parse generated code.");
529    if DEBUG {
530        println!("OUT: {out}");
531    }
532    out.into()
533}