magic_import/
lib.rs

1#![cfg_attr(not(feature = "stable"), feature(proc_macro_span))]
2
3use json::JsonValue;
4use proc_macro::TokenStream;
5use std::collections::{HashMap, HashSet};
6use std::process::Command;
7use std::str::FromStr;
8
9#[proc_macro]
10pub fn magic(input: TokenStream) -> TokenStream {
11    assert!(
12        input.is_empty(),
13        "auto_import::magic!() takes no arguments!"
14    );
15
16    #[cfg(feature = "stable")]
17    let key = {
18        use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
19        static ONCE: AtomicBool = AtomicBool::new(false);
20        if ONCE.swap(true, Relaxed) {
21            panic!("don't call `magic_import::magic!();` more than once per crate! (try using nightly?)");
22        }
23        if let Ok(imports) = std::env::var("autoimport") {
24            return TokenStream::from_str(&imports).unwrap();
25        }
26        "autoimport"
27    };
28
29    // need to keep file, key at the outermost scope for refs to live long enough
30    #[cfg(not(feature = "stable"))]
31    let file = {
32        use proc_macro::Span;
33        let file = Span::call_site().source_file();
34        if !file.is_real() {
35            // I don't know why this would ever be false or what a fake file even means, so don't handle it
36            return input;
37        }
38
39        // JSON output contains paths which ig is UTF-8 too. not quite sure what that's about.
40        // i think this'll panic with non-UTF8 stuff because of that, so therefore i assume valid UTF-8
41        file.path()
42            .into_os_string()
43            .into_string()
44            .expect("valid UTF-8")
45    };
46
47    #[cfg(not(feature = "stable"))]
48    let key: String = {
49        // uhh idk what's valid in env vars, from a quick google search it seems just alphanumeric and _ so better safe than sorry
50        "autoimport_"
51            .chars()
52            .chain(file.chars().filter(char::is_ascii_alphanumeric))
53            .collect()
54    };
55
56    #[cfg(not(feature = "stable"))]
57    let key: &str = {
58        use std::sync::Mutex;
59        lazy_static::lazy_static! {
60            static ref ONCE: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
61        }
62        let mut files = ONCE.lock().unwrap();
63        if files.contains(&key) {
64            // this poisons future invocations but uh, i guess that just prevents extra resources from being used for invalid invocations
65            panic!("don't call auto_import::magic!() more than once per file!");
66        }
67        files.insert(key.to_string());
68
69        if let Ok(imports) = std::env::var(&key) {
70            return TokenStream::from_str(&imports).unwrap();
71        }
72
73        // autoimport launched this process to check for errors, but this is NOT the correct invocation of the macro
74        if let Ok(_) = std::env::var("autoimport") {
75            return input;
76        }
77
78        &key
79    };
80
81    let mut imports = HashSet::<String>::new();
82    let mut more_imports = HashSet::<String>::new();
83    let mut excluded = HashSet::<String>::new();
84
85    for _ in 0..10 {
86        let mut args = std::env::args_os();
87        let out = Command::new(args.next().unwrap())
88            .args(args.filter(|arg| {
89                arg.to_str()
90                    .map_or(true, |s| !s.starts_with("--error-format="))
91            }))
92            .arg("--error-format=json")
93            .envs([
94                #[cfg(not(feature = "stable"))]
95                ("autoimport", "YES_SO_DONT_EVEN_TRY_ANYTHING"),
96                (
97                    key,
98                    &imports
99                        .iter()
100                        .flat_map(|s| ["use ", s, ";"])
101                        .collect::<String>(),
102                ),
103            ])
104            .output()
105            .unwrap();
106        if out.status.success() {
107            break;
108        }
109        for line in std::str::from_utf8(&out.stderr)
110            .unwrap()
111            .lines()
112            .filter(|l| l.starts_with('{'))
113        {
114            if let Ok(json) = json::parse(line) {
115                #[cfg(not(feature = "stable"))]
116                {
117                    if json["children"].members().chain([&json]).any(|c| {
118                        c["spans"].members().any(|span| {
119                            // assert_eq will contain "similarly named macro `assert` defined here"
120                            // with "is_primary": false, so therefore only check path for the
121                            span["is_primary"].as_bool().unwrap_or(false)
122                                && span["file_name"]
123                                    .as_str()
124                                    .map_or(false, |error_file| error_file != file)
125                        })
126                    }) {
127                        continue;
128                    }
129                }
130                more_imports.extend(
131                    error(&json)
132                        .into_iter()
133                        .filter(|&s| !imports.contains(s))
134                        .filter(|&s| !excluded.contains(s))
135                        .map(Into::into),
136                );
137            }
138        }
139
140        if more_imports.is_empty() {
141            break;
142        }
143
144        let mut idents: HashMap<String, Vec<String>> = HashMap::new();
145        for suggestion in more_imports.drain().chain(imports.drain()) {
146            let ident = suggestion.split("::").last().unwrap();
147            let suggestions_for_ident = idents.entry(ident.to_string()).or_default();
148            suggestions_for_ident.push(suggestion);
149        }
150        for (ident, suggestions) in idents {
151            let (best, exclude) = disambiguate(ident, suggestions);
152            imports.insert(best);
153            for bad in exclude {
154                excluded.insert(bad);
155            }
156        }
157    }
158    for import in &imports {
159        println!("\x1b[1;32m   Injecting\x1b[m use {import};");
160    }
161    TokenStream::from_str(
162        &imports
163            .iter()
164            .flat_map(|s| ["use ", s, ";"])
165            .collect::<String>(),
166    )
167    .unwrap()
168}
169
170fn error<'a>(json: &'a JsonValue) -> Vec<&'a str> {
171    if json["code"].is_null() {
172        let message = json["message"].as_str().unwrap_or_default();
173        if extract("cannot find macro `", message, "` in this scope").is_some() {
174            let message = json["children"][0]["message"].as_str().unwrap_or_default();
175            if let Some(suggestions) =
176                extract("consider importing one of these items:", message, "")
177            {
178                return suggestions
179                    .split_terminator("\n")
180                    .filter(|s| !s.is_empty())
181                    .collect();
182            } else if let Some(suggestion) =
183                extract("consider importing this macro:\n", message, "")
184            {
185                return vec![suggestion];
186            }
187        }
188    }
189    json["children"]
190        .members()
191        .flat_map(|c| {
192            c["spans"]
193                .members()
194                .map(|s| s["suggested_replacement"].as_str().unwrap_or_default())
195                .filter(|s| !s.is_empty())
196                .filter_map(|s| extract("use ", s.trim(), ";"))
197        })
198        .collect()
199}
200
/// Returns what remains of `message` after removing the prefix `start` and the suffix `end`,
/// or `None` if `message` does not begin with `start` and end with `end`.
///
/// The prefix and suffix must not overlap: for example `extract("ab", "aba", "ba")` yields
/// `None` rather than panicking on an inverted slice range.
fn extract<'a>(start: &str, message: &'a str, end: &str) -> Option<&'a str> {
    message.strip_prefix(start)?.strip_suffix(end)
}
208
209fn disambiguate(ident: String, mut suggestions: Vec<String>) -> (String, Vec<String>) {
210    assert!(!suggestions.is_empty());
211    if suggestions.len() == 1 {
212        return (suggestions.remove(0), Vec::new());
213    }
214    for i in 0..(suggestions.len() - 1) {
215        for j in (i + 1)..suggestions.len() {
216            if std_and_core(&suggestions[i], &suggestions[j]) {
217                suggestions.swap_remove(j);
218                return disambiguate(ident, suggestions);
219            } else if std_and_core(&suggestions[j], &suggestions[i]) {
220                suggestions.swap_remove(i);
221                return disambiguate(ident, suggestions);
222            }
223        }
224    }
225
226    // 1. prelude first
227    // 2. stable over unstable
228    // 3. more common (such as std::ops) over uncommon (like collection-specific things)
229    // list the excluded things as well
230    const DEFAULTS: &[&str] = &[
231        // (more common)
232        // - std::collections::btree_map::Range
233        // - std::collections::btree_set::Range
234        "std::ops::Range",
235        // (prelude)
236        // - std::fmt::Result
237        // - std::io::Result
238        // - std::thread::Result
239        "std::result::Result",
240        // (prelude)
241        // - std::fmt::Error
242        // - std::io::Error
243        "std::error::Error",
244        // (unstable)
245        // - std::io::read_to_string
246        "std::fs::read_to_string",
247    ];
248
249    if let Some(index) = suggestions
250        .iter()
251        .position(|s| DEFAULTS.contains(&s.as_str()))
252    {
253        let result = suggestions.swap_remove(index);
254        return (result, suggestions);
255    }
256
257    use rand::prelude::*;
258    println!("\x1b[1;33m   Ambiguity\x1b[m for {ident}");
259
260    for suggestion in &suggestions {
261        println!("\x1b[1;31m            \x1b[m {suggestion}");
262    }
263
264    println!("\x1b[1;32m     Picking\x1b[m at random");
265    println!("\x1b[1;32m      Hoping\x1b[m for the best");
266    let index = (0..suggestions.len()).choose(&mut thread_rng()).unwrap();
267    let result = suggestions.swap_remove(index);
268    return (result, suggestions);
269}
270
/// Reports whether `a` and `b` spell the same item through the two facade crates, with `a`
/// being the preferred spelling.
///
/// By default `a` must be `std::X` and `b` the matching `core::X`. With the `prefer_core`
/// feature the roles swap: `a` must be `core::X` and `b` the matching `std::X`.
fn std_and_core(a: &str, b: &str) -> bool {
    #[cfg(feature = "prefer_core")]
    let (a, b) = (b, a);
    match (a.strip_prefix("std::"), b.strip_prefix("core::")) {
        (Some(std_rest), Some(core_rest)) => std_rest == core_rest,
        _ => false,
    }
}