Skip to main content

cargo_group_imports/
lib.rs

1#[cfg(test)]
2mod test;
3
4use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
5use std::hash::Hash;
6use std::io::Write;
7use std::path::{Path, PathBuf};
8
9use clap::Parser;
10use itertools::Itertools;
11use log::*;
12use tree_sitter::TreeCursor;
13
14#[derive(Parser)]
15#[command(name = "cargo")]
16#[command(bin_name = "cargo")]
17pub enum MainFlags {
18    GroupImports(Flags),
19}
20
21/// Group imports in workspace source files.
22///
23/// This roughly corresponds to the `group_imports` unstable rustfmt option, with the difference
24/// that `rustfmt` does not distinguish workspace crates from external ones.
25///
26/// By default, displays a diff without applying changes. Returns code 0 when no changes are
27/// necessary.
28/// The --fix flag allows applying the changes.
29///
30/// See
31/// https://rust-lang.github.io/rustfmt/?version=v1.4.38&search=#group_imports
32/// https://github.com/rust-lang/rustfmt/blob/master/src/reorder.rs
33#[derive(Parser)]
34#[clap(about, version)]
35pub struct Flags {
36    #[clap(default_value_os_t = std::env::current_dir().unwrap())]
37    pub workspace: PathBuf,
38    /// Apply changes
39    #[clap(long)]
40    pub fix: bool,
41    #[clap(skip = true)]
42    pub rustfmt: bool,
43    #[clap(long, default_value_t = clap::ColorChoice::Auto)]
44    color: clap::ColorChoice,
45}
46impl Flags {
47    pub fn write_style(&self) -> env_logger::WriteStyle {
48        match self.color {
49            clap::ColorChoice::Auto => env_logger::WriteStyle::Auto,
50            clap::ColorChoice::Always => env_logger::WriteStyle::Always,
51            clap::ColorChoice::Never => env_logger::WriteStyle::Never,
52        }
53    }
54}
55
56#[derive(Default, Debug)]
57struct Use {
58    start: tree_sitter::Point,
59    end: tree_sitter::Point,
60    contents: String,
61    module: String,
62    module_decl: bool,
63}
/// Category of an import, used as the grouping key.
///
/// The declaration order of the variants is significant: the derived `Ord` decides the order in
/// which the groups are written out, from `Module` first to `Crate` last.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum UseType {
    /// `mod` declarations, `self::` paths and paths starting with a module declared in the file.
    Module,
    /// Paths starting with `std`.
    Std,
    /// Any other crate.
    External,
    /// Other packages of the workspace.
    Workspace,
    /// The current package, `crate::` and `super::` paths.
    Crate,
}
73
74fn node_as_utf8(node: tree_sitter::Node<'_>, source: &str) -> anyhow::Result<String> {
75    Ok(node.utf8_text(source.as_bytes())?.to_string())
76}
77/// Process a `use` or `mod` line, extracting the module name, comments, and attributes.
78fn process_line(cursor: &mut TreeCursor, source: &str) -> anyhow::Result<Use> {
79    let node = cursor.node();
80    let mut u = Use {
81        start: node.range().start_point,
82        end: node.range().end_point,
83        module_decl: node.kind() == "mod_item",
84        ..Default::default()
85    };
86    let mut contents = vec![node_as_utf8(node, source)?];
87    // Include comments and cfg
88    let mut sibling = node;
89    while let Some(s) = sibling.prev_sibling() {
90        sibling = s;
91
92        if ["line_comment", "attribute_item", "inner_attribute_item"].contains(&sibling.kind()) {
93            let content = node_as_utf8(sibling, source)?;
94            // Don't take module-level comments along
95            if !content.starts_with("//!") {
96                u.start = sibling.range().start_point;
97                contents.push(content);
98            }
99        } else {
100            break;
101        }
102    }
103    u.contents = contents.into_iter().rev().join("\n");
104    // Find module
105    cursor.goto_first_child();
106    while cursor.goto_next_sibling() {
107        if [
108            "identifier",
109            "scoped_identifier",
110            "use_wildcard",
111            "use_as_clause",
112            "scoped_use_list",
113        ]
114        .contains(&cursor.node().kind())
115        {
116            u.module = node_as_utf8(cursor.node(), source)?
117                .split("::")
118                .find(|s| !s.is_empty())
119                .unwrap()
120                .to_string();
121            break;
122        }
123    }
124    cursor.goto_parent();
125    Ok(u)
126}
127
128pub type WorkspacePackages = HashMap<String, camino::Utf8PathBuf>;
129
130/// Returns whether the file has been changed, or would have been changed.
131pub fn process_file(
132    filename: &Path,
133    package_name: &str,
134    workspace_packages: &WorkspacePackages,
135    args: &Flags,
136) -> anyhow::Result<bool> {
137    // Phase 1: parse with tree-sitter
138    let mut parser = tree_sitter::Parser::new();
139    parser.set_language(tree_sitter_rust::language())?;
140    let source = std::fs::read_to_string(filename)?;
141    let tree = parser.parse(&source, None).unwrap();
142
143    // Phase 2: find `use` and `mod` statements
144    let mut uses: Vec<Use> = vec![];
145    let mut mods_names = HashSet::<String>::default();
146    let mut macros_defs = HashSet::<String>::default();
147    let mut cursor = tree.walk();
148    cursor.goto_first_child();
149    loop {
150        let node = cursor.node();
151        if node.kind() == "macro_definition" {
152            cursor.goto_first_child();
153            cursor.goto_next_sibling();
154            macros_defs.insert(node_as_utf8(cursor.node(), &source)?);
155            cursor.goto_parent();
156        }
157        // Use node
158        if node.kind() == "use_declaration" {
159            uses.push(process_line(&mut cursor, &source)?);
160        } else if node.kind() == "mod_item" {
161            let mut decl_list = false;
162            cursor.goto_first_child();
163            loop {
164                if cursor.node().kind() == "identifier" {
165                    mods_names.insert(node_as_utf8(cursor.node(), &source)?);
166                } else if cursor.node().kind() == "declaration_list" {
167                    decl_list = true;
168                }
169                if !cursor.goto_next_sibling() {
170                    break;
171                }
172            }
173            cursor.goto_parent();
174            if !decl_list {
175                uses.push(process_line(&mut cursor, &source)?);
176            }
177            // TODO: Look into sub-modules
178        } else {
179            match uses.last() {
180                Some(u) if node.range().start_point.row == u.end.row => {
181                    // Simplification for the deletion later.
182                    anyhow::bail!(
183                        "use or mod expression on line {} contains another expression. This is unsupported.",
184                        u.end.row
185                    );
186                }
187                _ => {}
188            }
189        }
190        if !cursor.goto_next_sibling() {
191            break;
192        }
193    }
194    debug!("Macros: {:?}", macros_defs);
195    // Special case of macros_rules declarations, where the pub use must be after the definition.
196    uses.retain(|u| !macros_defs.contains(&u.module));
197    debug!("Modules: {:?}", mods_names);
198
199    // Phase 3: Group imports
200    let mut grouped = BTreeMap::<UseType, Vec<&Use>>::default();
201    for u in &uses {
202        let import_type = if u.module == "std" {
203            UseType::Std
204        } else if u.module == package_name || u.module == "crate" || u.module == "super" {
205            UseType::Crate
206        } else if mods_names.contains(&u.module) || u.module_decl || u.module == "self" {
207            UseType::Module
208        } else if workspace_packages.contains_key(&u.module) {
209            UseType::Workspace
210        } else {
211            UseType::External
212        };
213        grouped.entry(import_type).or_default().push(u);
214    }
215    debug!("Grouped uses {:#?}", grouped);
216
217    // Phase 4: Insert into source file
218    let imports = grouped
219        .values()
220        .map(|uses| {
221            uses.iter()
222                .map(|u| &u.contents)
223                .chain(std::iter::once(&Default::default()))
224                .join("\n")
225        })
226        .join("\n");
227
228    let lines: BTreeSet<usize> = grouped
229        .values()
230        .flatten()
231        .flat_map(|l| (l.start.row..=l.end.row))
232        .collect();
233    let mut source_modified = source
234        .lines()
235        .enumerate()
236        .filter_map(|(i, l)| {
237            if lines.iter().next() == Some(&i) {
238                Some(imports.as_str())
239            } else if
240            // We ensured earlier that these lines do not contain anything else
241            lines.contains(&i)
242                ||
243            // Remove previous spacing
244            l.is_empty() && (i > 0 && lines.contains(&(i - 1)))
245            {
246                None
247            } else {
248                Some(l)
249            }
250        })
251        // New line at end
252        .chain(std::iter::once(""))
253        .join("\n");
254
255    // Phase 4: Run rustfmt; this should not be needed in most cases.
256    // TODO: Ensure it is not needed. The difference comes from the ordering of
257    // super::,crate:: etc. imports. Most of the runtime is due to running rustfmt.
258    let modified = source != source_modified;
259    if modified && args.rustfmt {
260        let mut cmd = std::process::Command::new("rustfmt")
261            .current_dir(&args.workspace)
262            .stdin(std::process::Stdio::piped())
263            .stdout(std::process::Stdio::piped())
264            .spawn()?;
265        let mut stdin = cmd.stdin.take().unwrap();
266        stdin.write_all(source_modified.as_bytes())?;
267        drop(stdin);
268        let out = cmd.wait_with_output()?;
269        anyhow::ensure!(out.status.success(), "Calling rustfmt failed");
270        source_modified = String::from_utf8(out.stdout)?;
271    }
272
273    // Phase 5: Write output or diff
274    let modified = source != source_modified;
275    if modified {
276        if !args.fix {
277            warn!(
278                "Diff in {:?}:\n{}",
279                filename,
280                prettydiff::diff_lines(&source, &source_modified).format_with_context(
281                    Some(prettydiff::text::ContextConfig {
282                        context_size: 5,
283                        skipping_marker: "..."
284                    }),
285                    true
286                )
287            );
288        } else {
289            std::fs::write(filename, &source_modified)?;
290            info!("Wrote {:?}", filename);
291        }
292    }
293    Ok(modified)
294}