cargo_group_imports/
lib.rs1#[cfg(test)]
2mod test;
3
4use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
5use std::hash::Hash;
6use std::io::Write;
7use std::path::{Path, PathBuf};
8
9use clap::Parser;
10use itertools::Itertools;
11use log::*;
12use tree_sitter::TreeCursor;
13
14#[derive(Parser)]
15#[command(name = "cargo")]
16#[command(bin_name = "cargo")]
17pub enum MainFlags {
18 GroupImports(Flags),
19}
20
21#[derive(Parser)]
34#[clap(about, version)]
35pub struct Flags {
36 #[clap(default_value_os_t = std::env::current_dir().unwrap())]
37 pub workspace: PathBuf,
38 #[clap(long)]
40 pub fix: bool,
41 #[clap(skip = true)]
42 pub rustfmt: bool,
43 #[clap(long, default_value_t = clap::ColorChoice::Auto)]
44 color: clap::ColorChoice,
45}
46impl Flags {
47 pub fn write_style(&self) -> env_logger::WriteStyle {
48 match self.color {
49 clap::ColorChoice::Auto => env_logger::WriteStyle::Auto,
50 clap::ColorChoice::Always => env_logger::WriteStyle::Always,
51 clap::ColorChoice::Never => env_logger::WriteStyle::Never,
52 }
53 }
54}
55
56#[derive(Default, Debug)]
57struct Use {
58 start: tree_sitter::Point,
59 end: tree_sitter::Point,
60 contents: String,
61 module: String,
62 module_decl: bool,
63}
/// Category an import is sorted into.
///
/// The declaration order of the variants is significant: the derived `Ord`
/// follows it, and `process_file` emits the groups in that order because it
/// keys a `BTreeMap` by this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum UseType {
    /// `mod` declarations, paths starting with `self`, and paths starting with
    /// a module declared in the same file.
    Module,
    /// Paths starting with `std`.
    Std,
    /// Everything that matches no other category.
    External,
    /// Paths starting with the name of another workspace package.
    Workspace,
    /// Paths starting with `crate`, `super`, or the current package's name.
    Crate,
}
73
74fn node_as_utf8(node: tree_sitter::Node<'_>, source: &str) -> anyhow::Result<String> {
75 Ok(node.utf8_text(source.as_bytes())?.to_string())
76}
77fn process_line(cursor: &mut TreeCursor, source: &str) -> anyhow::Result<Use> {
79 let node = cursor.node();
80 let mut u = Use {
81 start: node.range().start_point,
82 end: node.range().end_point,
83 module_decl: node.kind() == "mod_item",
84 ..Default::default()
85 };
86 let mut contents = vec![node_as_utf8(node, source)?];
87 let mut sibling = node;
89 while let Some(s) = sibling.prev_sibling() {
90 sibling = s;
91
92 if ["line_comment", "attribute_item", "inner_attribute_item"].contains(&sibling.kind()) {
93 let content = node_as_utf8(sibling, source)?;
94 if !content.starts_with("//!") {
96 u.start = sibling.range().start_point;
97 contents.push(content);
98 }
99 } else {
100 break;
101 }
102 }
103 u.contents = contents.into_iter().rev().join("\n");
104 cursor.goto_first_child();
106 while cursor.goto_next_sibling() {
107 if [
108 "identifier",
109 "scoped_identifier",
110 "use_wildcard",
111 "use_as_clause",
112 "scoped_use_list",
113 ]
114 .contains(&cursor.node().kind())
115 {
116 u.module = node_as_utf8(cursor.node(), source)?
117 .split("::")
118 .find(|s| !s.is_empty())
119 .unwrap()
120 .to_string();
121 break;
122 }
123 }
124 cursor.goto_parent();
125 Ok(u)
126}
127
128pub type WorkspacePackages = HashMap<String, camino::Utf8PathBuf>;
129
130pub fn process_file(
132 filename: &Path,
133 package_name: &str,
134 workspace_packages: &WorkspacePackages,
135 args: &Flags,
136) -> anyhow::Result<bool> {
137 let mut parser = tree_sitter::Parser::new();
139 parser.set_language(tree_sitter_rust::language())?;
140 let source = std::fs::read_to_string(filename)?;
141 let tree = parser.parse(&source, None).unwrap();
142
143 let mut uses: Vec<Use> = vec![];
145 let mut mods_names = HashSet::<String>::default();
146 let mut macros_defs = HashSet::<String>::default();
147 let mut cursor = tree.walk();
148 cursor.goto_first_child();
149 loop {
150 let node = cursor.node();
151 if node.kind() == "macro_definition" {
152 cursor.goto_first_child();
153 cursor.goto_next_sibling();
154 macros_defs.insert(node_as_utf8(cursor.node(), &source)?);
155 cursor.goto_parent();
156 }
157 if node.kind() == "use_declaration" {
159 uses.push(process_line(&mut cursor, &source)?);
160 } else if node.kind() == "mod_item" {
161 let mut decl_list = false;
162 cursor.goto_first_child();
163 loop {
164 if cursor.node().kind() == "identifier" {
165 mods_names.insert(node_as_utf8(cursor.node(), &source)?);
166 } else if cursor.node().kind() == "declaration_list" {
167 decl_list = true;
168 }
169 if !cursor.goto_next_sibling() {
170 break;
171 }
172 }
173 cursor.goto_parent();
174 if !decl_list {
175 uses.push(process_line(&mut cursor, &source)?);
176 }
177 } else {
179 match uses.last() {
180 Some(u) if node.range().start_point.row == u.end.row => {
181 anyhow::bail!(
183 "use or mod expression on line {} contains another expression. This is unsupported.",
184 u.end.row
185 );
186 }
187 _ => {}
188 }
189 }
190 if !cursor.goto_next_sibling() {
191 break;
192 }
193 }
194 debug!("Macros: {:?}", macros_defs);
195 uses.retain(|u| !macros_defs.contains(&u.module));
197 debug!("Modules: {:?}", mods_names);
198
199 let mut grouped = BTreeMap::<UseType, Vec<&Use>>::default();
201 for u in &uses {
202 let import_type = if u.module == "std" {
203 UseType::Std
204 } else if u.module == package_name || u.module == "crate" || u.module == "super" {
205 UseType::Crate
206 } else if mods_names.contains(&u.module) || u.module_decl || u.module == "self" {
207 UseType::Module
208 } else if workspace_packages.contains_key(&u.module) {
209 UseType::Workspace
210 } else {
211 UseType::External
212 };
213 grouped.entry(import_type).or_default().push(u);
214 }
215 debug!("Grouped uses {:#?}", grouped);
216
217 let imports = grouped
219 .values()
220 .map(|uses| {
221 uses.iter()
222 .map(|u| &u.contents)
223 .chain(std::iter::once(&Default::default()))
224 .join("\n")
225 })
226 .join("\n");
227
228 let lines: BTreeSet<usize> = grouped
229 .values()
230 .flatten()
231 .flat_map(|l| (l.start.row..=l.end.row))
232 .collect();
233 let mut source_modified = source
234 .lines()
235 .enumerate()
236 .filter_map(|(i, l)| {
237 if lines.iter().next() == Some(&i) {
238 Some(imports.as_str())
239 } else if
240 lines.contains(&i)
242 ||
243 l.is_empty() && (i > 0 && lines.contains(&(i - 1)))
245 {
246 None
247 } else {
248 Some(l)
249 }
250 })
251 .chain(std::iter::once(""))
253 .join("\n");
254
255 let modified = source != source_modified;
259 if modified && args.rustfmt {
260 let mut cmd = std::process::Command::new("rustfmt")
261 .current_dir(&args.workspace)
262 .stdin(std::process::Stdio::piped())
263 .stdout(std::process::Stdio::piped())
264 .spawn()?;
265 let mut stdin = cmd.stdin.take().unwrap();
266 stdin.write_all(source_modified.as_bytes())?;
267 drop(stdin);
268 let out = cmd.wait_with_output()?;
269 anyhow::ensure!(out.status.success(), "Calling rustfmt failed");
270 source_modified = String::from_utf8(out.stdout)?;
271 }
272
273 let modified = source != source_modified;
275 if modified {
276 if !args.fix {
277 warn!(
278 "Diff in {:?}:\n{}",
279 filename,
280 prettydiff::diff_lines(&source, &source_modified).format_with_context(
281 Some(prettydiff::text::ContextConfig {
282 context_size: 5,
283 skipping_marker: "..."
284 }),
285 true
286 )
287 );
288 } else {
289 std::fs::write(filename, &source_modified)?;
290 info!("Wrote {:?}", filename);
291 }
292 }
293 Ok(modified)
294}