use std::path::Path;
use syn::{Item, UseTree};
use crate::{
config::Config,
diagnostic::{Diagnostic, Severity},
rules::Rule,
};
/// Maximum width, in bytes, of an emitted `use` line (rustfmt's default `max_width`).
const MAX_LINE_WIDTH: usize = 100;
/// One indentation level inside multi-line `{...}` lists: four spaces, as rustfmt emits
/// and as the width arithmetic in the multi-line layout assumes.
const INDENT: &str = "    ";
/// Section an import is placed in when grouping is enabled.
///
/// The declaration order is load-bearing: the derived `Ord` sorts `Std` first,
/// then `External`, then `Internal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum ImportGroup {
    /// Rooted at `std`, `core` or `alloc`.
    Std,
    /// Rooted at any other identifier (third-party crates).
    External,
    /// Rooted at `crate`, `self` or `super`.
    Internal,
}
/// One top-level `use` declaration, reduced to the parts this rule re-emits.
#[derive(Debug, Clone)]
struct NormalizedUse {
    /// Rendered visibility: `""`, `"pub"`, `"pub(crate)"`, `"pub(in a::b)"`, ...
    visibility: String,
    /// The imported tree, without the `use` keyword and the trailing `;`.
    tree: UseNode,
    /// Section derived from the tree's root identifier.
    group: ImportGroup,
}
/// Owned mirror of a `use` tree, independent of syn's types.
#[derive(Debug, Clone, PartialEq, Eq)]
enum UseNode {
    /// `ident::child`
    Path { ident: String, child: Box<UseNode> },
    /// `ident`, or `ident as rename`.
    Name {
        ident: String,
        rename: Option<String>,
    },
    /// A `self` entry (optionally `self as rename`). `use_tree_to_node` never
    /// creates this variant; it is synthesized when imports are merged.
    Slf { rename: Option<String> },
    /// `*`
    Glob,
    /// `{item, item, ...}`
    Group { items: Vec<UseNode> },
}
/// Lint rule that checks and rewrites the layout of a file's leading `use` block.
pub struct ImportFormattingRule {
    /// Separate std / external / internal imports with a blank line.
    group: bool,
    /// Sort imports, and the entries inside every `{...}` list.
    sort: bool,
    /// Combine imports that share a visibility and root identifier.
    merge: bool,
}
impl ImportFormattingRule {
    /// Creates a rule with explicit switches for grouping, sorting and merging.
    pub fn new(group: bool, sort: bool, merge: bool) -> Self {
        Self { group, sort, merge }
    }

    /// Creates a rule from the `imports` settings of the project configuration.
    pub fn from_config(config: &Config) -> Self {
        Self {
            group: config.imports.group,
            sort: config.imports.sort,
            merge: config.imports.merge,
        }
    }

    /// Returns `content` with the `use` region found by `find_use_region`
    /// re-emitted in normalized form.
    ///
    /// Lines before the region are kept verbatim. Blank lines directly after it
    /// collapse into exactly one empty line before the following code. Whether
    /// the input ends with a newline is preserved. When no region is found, or
    /// `parse_use_items` declines to parse it, `content` is returned unchanged.
    fn format_imports(&self, content: &str) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let has_trailing_newline = content.ends_with('\n');

        let Some((region_start, region_end)) = find_use_region(&lines) else {
            return content.to_string();
        };
        let region_text: String = lines[region_start..=region_end]
            .iter()
            .map(|l| format!("{l}\n"))
            .collect();
        let mut imports = match parse_use_items(&region_text) {
            Some(imports) if !imports.is_empty() => imports,
            _ => return content.to_string(),
        };

        if self.merge {
            imports = merge_imports(imports);
        }
        if self.sort {
            for imp in &mut imports {
                sort_use_node(&mut imp.tree);
            }
            if self.group {
                imports.sort_by(|a, b| {
                    a.group
                        .cmp(&b.group)
                        .then_with(|| cmp_use_nodes(&a.tree, &b.tree))
                });
            } else {
                imports.sort_by(|a, b| cmp_use_nodes(&a.tree, &b.tree));
            }
        } else if self.group {
            // Without sorting, still cluster each group together. The sort is
            // stable, so the original order inside a group is kept. Otherwise
            // interleaved groups would get a blank line at every switch.
            imports.sort_by_key(|imp| imp.group);
        }

        let mut result = String::with_capacity(content.len());
        for line in &lines[..region_start] {
            result.push_str(line);
            result.push('\n');
        }
        result.push_str(&format_all_imports(&imports, self.group));

        // Skip the blank lines after the region; one separator line is re-inserted.
        let rest = &lines[region_end + 1..];
        if let Some(offset) = rest.iter().position(|l| !l.trim().is_empty()) {
            result.push('\n');
            result.push_str(&rest[offset..].join("\n"));
        }

        // Mirror the input's trailing-newline state so a file is never flagged
        // merely for lacking (or having) a final newline.
        if has_trailing_newline {
            if !result.ends_with('\n') {
                result.push('\n');
            }
        } else if result.ends_with('\n') {
            result.pop();
        }
        result
    }
}
impl Rule for ImportFormattingRule {
    fn id(&self) -> &str {
        "RC1001"
    }

    fn name(&self) -> &str {
        "ImportFormatting"
    }

    /// Emits a single warning, anchored at line 1, when `fix` would change `content`.
    fn check(&self, content: &str, file: &Path) -> Vec<Diagnostic> {
        if self.format_imports(content) == content {
            return Vec::new();
        }
        let message =
            "Import statements are not properly formatted. Run `rustcop fix` to auto-fix.";
        vec![Diagnostic {
            rule_id: self.id().to_string(),
            message: message.to_string(),
            file: file.to_path_buf(),
            line: 1,
            severity: Severity::Warning,
        }]
    }

    /// Returns `content` with its import block reformatted.
    fn fix(&self, content: &str) -> String {
        self.format_imports(content)
    }
}
/// Parses the text of a `use` region into normalized imports.
///
/// Returns `None` whenever re-emitting the region could lose source text:
/// - the text fails to parse;
/// - it contains `//` or `/* */` comments, which syn discards;
/// - it contains any item other than `use`;
/// - a `use` carries attributes (including `///` doc comments) or a leading `::`,
///   which the normalized form cannot represent.
fn parse_use_items(region_text: &str) -> Option<Vec<NormalizedUse>> {
    if region_text.contains("//") || region_text.contains("/*") {
        return None;
    }
    let file: syn::File = syn::parse_str(region_text).ok()?;
    let mut result = Vec::with_capacity(file.items.len());
    for item in &file.items {
        // Anything but a plain `use` (e.g. a `fn` swept in by brace tracking)
        // would be dropped on re-emission, so refuse instead.
        let Item::Use(item_use) = item else {
            return None;
        };
        if !item_use.attrs.is_empty() || item_use.leading_colon.is_some() {
            return None;
        }
        let tree = use_tree_to_node(&item_use.tree);
        let group = classify_node(&tree);
        result.push(NormalizedUse {
            visibility: format_visibility(&item_use.vis),
            tree,
            group,
        });
    }
    Some(result)
}
/// Converts a syn `UseTree` into the owned `UseNode` form.
///
/// Renames become `Name` with `rename` set; groups are converted recursively.
fn use_tree_to_node(tree: &UseTree) -> UseNode {
    match tree {
        UseTree::Glob(_) => UseNode::Glob,
        UseTree::Name(name) => UseNode::Name {
            ident: name.ident.to_string(),
            rename: None,
        },
        UseTree::Rename(renamed) => UseNode::Name {
            ident: renamed.ident.to_string(),
            rename: Some(renamed.rename.to_string()),
        },
        UseTree::Path(path) => {
            let child = use_tree_to_node(&path.tree);
            UseNode::Path {
                ident: path.ident.to_string(),
                child: Box::new(child),
            }
        }
        UseTree::Group(group) => {
            let items = group.items.iter().map(use_tree_to_node).collect();
            UseNode::Group { items }
        }
    }
}
/// Renders a visibility as source text: `""`, `"pub"`, `"pub(crate)"`,
/// `"pub(in a::b)"`, ...
fn format_visibility(vis: &syn::Visibility) -> String {
    let restricted = match vis {
        syn::Visibility::Inherited => return String::new(),
        syn::Visibility::Public(_) => return "pub".to_string(),
        syn::Visibility::Restricted(restricted) => restricted,
    };
    let path = restricted
        .path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect::<Vec<_>>()
        .join("::");
    match &restricted.in_token {
        Some(_) => format!("pub(in {path})"),
        None => format!("pub({path})"),
    }
}
/// Chooses the section for an import from its root identifier.
fn classify_node(node: &UseNode) -> ImportGroup {
    const STD_ROOTS: [&str; 3] = ["std", "core", "alloc"];
    const INTERNAL_ROOTS: [&str; 3] = ["crate", "self", "super"];

    let root = root_ident(node);
    let root = root.as_str();
    if STD_ROOTS.contains(&root) {
        ImportGroup::Std
    } else if INTERNAL_ROOTS.contains(&root) {
        ImportGroup::Internal
    } else {
        ImportGroup::External
    }
}
/// Returns the first identifier of an import tree (`"std"` for `std::io::Read`).
///
/// A `Slf` node reports `"self"` and a glob reports `"*"`. A group reports the
/// root of its first item, or an empty string when the group is empty.
fn root_ident(node: &UseNode) -> String {
    let root: &str = match node {
        UseNode::Path { ident, .. } | UseNode::Name { ident, .. } => ident.as_str(),
        UseNode::Slf { .. } => "self",
        UseNode::Glob => "*",
        UseNode::Group { items } => {
            return match items.first() {
                Some(first) => root_ident(first),
                None => String::new(),
            };
        }
    };
    root.to_owned()
}
/// Reports whether an already-trimmed line begins a `use` declaration
/// (`use`, `pub use`, or a restricted form such as `pub(crate) use`).
fn is_use_line(trimmed: &str) -> bool {
    const PLAIN_PREFIXES: [&str; 2] = ["use ", "pub use "];
    if PLAIN_PREFIXES
        .iter()
        .any(|&prefix| trimmed.starts_with(prefix))
    {
        return true;
    }
    // Restricted visibility: `pub(crate) use`, `pub(super) use`, `pub(in path) use`, ...
    trimmed.starts_with("pub(") && trimmed.contains(") use ")
}
/// Locates the first block of consecutive top-level `use` declarations.
///
/// Returns the 0-based, inclusive `(first_line, last_line)` span, or `None` if
/// there is no unindented `use` line.
///
/// - The block must start at an unindented `use`. Indented ones sit inside a
///   `mod`/`fn` body and would otherwise be de-indented by the rewrite.
/// - It continues through multi-line `{ ... }` trees, blank lines and `//`
///   comment lines, and stops at the first other line.
/// - `last_line` is the last line belonging to a `use` statement, so trailing
///   blanks and comments are excluded.
fn find_use_region(lines: &[&str]) -> Option<(usize, usize)> {
    /// Net change in `{`/`}` nesting contributed by one line.
    fn brace_delta(text: &str) -> i32 {
        text.chars()
            .map(|ch| match ch {
                '{' => 1,
                '}' => -1,
                _ => 0,
            })
            .sum()
    }

    let mut first_use: Option<usize> = None;
    let mut last_use_end = 0;
    let mut brace_depth: i32 = 0;

    for (i, line) in lines.iter().enumerate() {
        let trimmed = line.trim();
        if brace_depth > 0 {
            // Continuation of a multi-line `use ... {` tree.
            brace_depth += brace_delta(trimmed);
            last_use_end = i;
            continue;
        }
        if is_use_line(trimmed) {
            if first_use.is_none() {
                if line.starts_with(char::is_whitespace) {
                    // Nested inside some item body; not the file's import block.
                    continue;
                }
                first_use = Some(i);
            }
            brace_depth += brace_delta(trimmed);
            last_use_end = i;
        } else if first_use.is_some() && !trimmed.is_empty() && !trimmed.starts_with("//") {
            break;
        }
    }
    first_use.map(|start| (start, last_use_end))
}
/// Recursively sorts every `{...}` list inside `node` with `cmp_use_nodes`.
///
/// Nested lists are sorted before their parent list, so the parent's sort keys
/// (which render nested lists as text) see the already-sorted children.
fn sort_use_node(node: &mut UseNode) {
    match node {
        UseNode::Path { child, .. } => sort_use_node(child),
        UseNode::Group { items } => {
            items.iter_mut().for_each(sort_use_node);
            items.sort_by(cmp_use_nodes);
        }
        UseNode::Name { .. } | UseNode::Slf { .. } | UseNode::Glob => {}
    }
}
/// Total order used to sort imports and the entries of `{...}` lists.
///
/// Keys, most significant first:
/// 1. Kind: a plain `self` entry and `self::…` paths, then `super::…`, then
///    `crate::…`, then ordinary names/paths, then `*`, then nested groups.
/// 2. For ordinary names/paths, the case of the first identifier:
///    lowercase-initial, then CamelCase/other, then SCREAMING_CASE.
/// 3. The rendered path text, compared byte-wise.
///
/// A plain `self` entry ranks first whether it is represented as `Slf` or as a
/// `Name` whose ident is `self`. This matches rustfmt's `{self, ...}` layout.
fn cmp_use_nodes(a: &UseNode, b: &UseNode) -> std::cmp::Ordering {
    /// 0 = starts lowercase, 2 = uppercase-initial and entirely
    /// uppercase/`_`/digits, 1 = everything else.
    fn ident_case_category(s: &str) -> u8 {
        if s.starts_with(|c: char| c.is_lowercase()) {
            0
        } else if s.starts_with(|c: char| c.is_uppercase())
            && s
                .chars()
                .all(|c| c.is_uppercase() || c == '_' || c.is_numeric())
        {
            2
        } else {
            1
        }
    }

    fn sort_key(node: &UseNode) -> (u8, u8, String) {
        match node {
            UseNode::Slf { .. } => (0, 0, String::new()),
            // Same rank as `Slf`: both render as `self`.
            UseNode::Name { ident, .. } if ident == "self" => (0, 0, String::new()),
            UseNode::Path { ident, child } if ident == "self" => (0, 0, node_sort_suffix(child)),
            UseNode::Path { ident, child } if ident == "super" => (1, 0, node_sort_suffix(child)),
            UseNode::Path { ident, child } if ident == "crate" => (2, 0, node_sort_suffix(child)),
            UseNode::Path { ident, child } => (
                3,
                ident_case_category(ident),
                format!("{ident}::{}", node_sort_suffix(child)),
            ),
            UseNode::Name { ident, .. } => (3, ident_case_category(ident), ident.clone()),
            UseNode::Glob => (4, 0, String::new()),
            UseNode::Group { .. } => (5, 0, String::new()),
        }
    }

    // Tuple ordering is lexicographic: kind, then case category, then text.
    sort_key(a).cmp(&sort_key(b))
}
/// Renders `node` as path text for use as a sort key; renames are ignored.
fn node_sort_suffix(node: &UseNode) -> String {
    match node {
        UseNode::Name { ident, .. } => ident.clone(),
        UseNode::Slf { .. } => String::from("self"),
        UseNode::Glob => String::from("*"),
        UseNode::Path { ident, child } => {
            let mut text = ident.clone();
            text.push_str("::");
            text.push_str(&node_sort_suffix(child));
            text
        }
        UseNode::Group { items } => {
            let parts: Vec<String> = items.iter().map(node_sort_suffix).collect();
            format!("{{{}}}", parts.join(", "))
        }
    }
}
/// Merges imports sharing both visibility and root identifier into one
/// `use root::{...}` statement.
///
/// - Only path (`a::b`) and bare-name (`a`, `a as b`) imports are merged.
///   Root-level groups (`use {a, b};`) and globs cannot be re-rooted under a
///   single identifier; they are passed through unchanged, after the merged ones.
/// - Duplicate entries are removed wherever they occur (first occurrence wins),
///   since importing a name twice is a compile error. `Slf` and a `Name` whose
///   ident is `self` count as the same entry.
/// - A merge that leaves only a `self` entry is written as the bare root
///   (`use root;` / `use root as alias;`), because `use root::self;` is invalid.
///
/// Merged imports come out ordered by their `(visibility, root)` key.
fn merge_imports(imports: Vec<NormalizedUse>) -> Vec<NormalizedUse> {
    use std::collections::BTreeMap;

    let mut by_key: BTreeMap<(String, String), Vec<NormalizedUse>> = BTreeMap::new();
    let mut passthrough: Vec<NormalizedUse> = Vec::new();
    for imp in imports {
        if matches!(imp.tree, UseNode::Path { .. } | UseNode::Name { .. }) {
            let key = (imp.visibility.clone(), root_ident(&imp.tree));
            by_key.entry(key).or_default().push(imp);
        } else {
            passthrough.push(imp);
        }
    }

    let mut merged: Vec<NormalizedUse> = by_key
        .into_values()
        .map(|mut group| {
            if group.len() == 1 {
                return group.pop().expect("group has exactly one import");
            }
            let visibility = group[0].visibility.clone();
            let import_group = group[0].group;
            let root = root_ident(&group[0].tree);

            let mut collected = Vec::new();
            for imp in &group {
                collect_children_for_merge(&imp.tree, &mut collected);
            }
            let mut children: Vec<UseNode> = Vec::with_capacity(collected.len());
            for child in collected {
                // Parsed `self` and synthesized `Slf` must compare equal for dedup.
                let child = match child {
                    UseNode::Name { ident, rename } if ident == "self" => UseNode::Slf { rename },
                    other => other,
                };
                if !children.contains(&child) {
                    children.push(child);
                }
            }

            let tree = if children.len() > 1 {
                UseNode::Path {
                    ident: root,
                    child: Box::new(UseNode::Group { items: children }),
                }
            } else {
                match children.pop() {
                    None => UseNode::Name {
                        ident: root,
                        rename: None,
                    },
                    Some(UseNode::Slf { rename }) => UseNode::Name {
                        ident: root,
                        rename,
                    },
                    Some(child) => UseNode::Path {
                        ident: root,
                        child: Box::new(child),
                    },
                }
            };
            NormalizedUse {
                visibility,
                tree,
                group: import_group,
            }
        })
        .collect();
    merged.extend(passthrough);
    merged
}
/// Appends the entries `node` contributes when merged under its root.
///
/// - `root::{a, b}` contributes `a` and `b`.
/// - `root::x` contributes `x`.
/// - A bare `root` contributes `self`; `root as alias` contributes `self as alias`.
/// - Root-level `Slf`, globs and groups contribute nothing.
fn collect_children_for_merge(node: &UseNode, out: &mut Vec<UseNode>) {
    match node {
        UseNode::Name { rename, .. } => out.push(UseNode::Slf {
            rename: rename.clone(),
        }),
        UseNode::Path { child, .. } => {
            if let UseNode::Group { items } = child.as_ref() {
                out.extend_from_slice(items);
            } else {
                out.push(child.as_ref().clone());
            }
        }
        UseNode::Slf { .. } | UseNode::Glob | UseNode::Group { .. } => {}
    }
}
/// Renders each import as a statement followed by `\n`.
///
/// When `group` is set, a blank line is inserted wherever two consecutive
/// imports belong to different groups.
fn format_all_imports(imports: &[NormalizedUse], group: bool) -> String {
    let mut out = String::new();
    for (idx, imp) in imports.iter().enumerate() {
        let group_changed = idx > 0 && imports[idx - 1].group != imp.group;
        if group && group_changed {
            out.push('\n');
        }
        let vis_prefix = match imp.visibility.as_str() {
            "" => String::new(),
            vis => format!("{vis} "),
        };
        out.push_str(&format_use_stmt(&imp.tree, &vis_prefix));
        out.push('\n');
    }
    out
}
/// Renders one complete `use` statement (without a trailing newline).
///
/// The single-line form is kept when the tree has no braces at all, or has at
/// most one level of braces and fits in `MAX_LINE_WIDTH`. Otherwise the
/// multi-line layout is used.
fn format_use_stmt(node: &UseNode, vis_prefix: &str) -> String {
    let single_line = format!("{vis_prefix}use {};", format_node_to_path(node));
    let has_braces = single_line.contains('{');
    let fits_flat = max_brace_depth(node) <= 1 && single_line.len() <= MAX_LINE_WIDTH;
    if !has_braces || fits_flat {
        single_line
    } else {
        format_use_stmt_multiline(node, vis_prefix)
    }
}
/// Returns how deeply `{...}` lists nest within `node` (0 when there are none).
fn max_brace_depth(node: &UseNode) -> usize {
    match node {
        UseNode::Path { child, .. } => max_brace_depth(child),
        UseNode::Group { items } => {
            let deepest_child = items.iter().map(max_brace_depth).fold(0, usize::max);
            deepest_child + 1
        }
        UseNode::Name { .. } | UseNode::Slf { .. } | UseNode::Glob => 0,
    }
}
/// Renders `node` as single-line use-tree text, including `as` renames.
fn format_node_to_path(node: &UseNode) -> String {
    match node {
        UseNode::Glob => "*".to_string(),
        UseNode::Slf { rename } => match rename {
            Some(alias) => format!("self as {alias}"),
            None => "self".to_string(),
        },
        UseNode::Name { ident, rename } => match rename {
            Some(alias) => format!("{ident} as {alias}"),
            None => ident.clone(),
        },
        UseNode::Path { ident, child } => format!("{ident}::{}", format_node_to_path(child)),
        UseNode::Group { items } => {
            let rendered: Vec<String> = items.iter().map(format_node_to_path).collect();
            format!("{{{}}}", rendered.join(", "))
        }
    }
}
/// Renders a `use` statement whose `{...}` lists may span several lines.
fn format_use_stmt_multiline(node: &UseNode, vis_prefix: &str) -> String {
    let mut stmt = String::from(vis_prefix);
    stmt.push_str("use ");
    format_node_multiline(node, &mut stmt, 0);
    stmt.push(';');
    stmt
}
/// Appends `node` to `out`. Any `{...}` list is laid out by
/// `format_group_multiline` at nesting level `indent_level`.
fn format_node_multiline(node: &UseNode, out: &mut String, indent_level: usize) {
    match node {
        UseNode::Group { items } => format_group_multiline(items, out, indent_level),
        UseNode::Path { ident, child } => {
            out.push_str(ident);
            out.push_str("::");
            // A `Group` child is handled by the `Group` arm above, at the same level.
            format_node_multiline(child, out, indent_level);
        }
        UseNode::Name { .. } | UseNode::Slf { .. } | UseNode::Glob => {
            out.push_str(&format_node_to_path(node));
        }
    }
}
/// Appends a `{...}` list to `out`. `out` must already hold the statement text
/// up to where the opening brace goes.
///
/// The list stays on one line only when all of these hold:
/// - no item contains a nested list;
/// - every item would fit on its own indented line;
/// - the whole list, written from the current column plus its one-byte
///   terminator (`,` or `;`), fits in `MAX_LINE_WIDTH`.
///
/// Otherwise it is laid out vertically:
/// - items holding nested lists get their own line and recurse one level deeper;
/// - runs of plain items are packed greedily into `,`-terminated lines of at
///   most `MAX_LINE_WIDTH` bytes;
/// - the closing brace is indented at `indent_level`.
fn format_group_multiline(items: &[UseNode], out: &mut String, indent_level: usize) {
    let child_indent = INDENT.repeat(indent_level + 1);
    let close_indent = INDENT.repeat(indent_level);
    let rendered: Vec<String> = items.iter().map(format_node_to_path).collect();

    // Column of the `{`: bytes written since the last newline. This accounts
    // for the real indentation and path prefix, so the decision agrees with
    // the width check in `format_use_stmt`.
    let column = out.rfind('\n').map_or(out.len(), |nl| out.len() - nl - 1);
    // `{` + items + `}` + one terminator byte.
    let one_line_width = column + rendered.join(", ").len() + 2 + 1;

    let needs_multiline = one_line_width > MAX_LINE_WIDTH
        || items.iter().zip(&rendered).any(|(item, text)| {
            contains_group(item) || child_indent.len() + text.len() + 1 > MAX_LINE_WIDTH
        });

    if !needs_multiline {
        out.push('{');
        out.push_str(&rendered.join(", "));
        out.push('}');
        return;
    }

    out.push_str("{\n");
    let mut i = 0;
    while i < items.len() {
        if contains_group(&items[i]) {
            out.push_str(&child_indent);
            format_node_multiline(&items[i], out, indent_level + 1);
            out.push_str(",\n");
            i += 1;
            continue;
        }

        // Greedily pack the run of plain items starting at `i`.
        let mut line = child_indent.clone();
        while i < items.len() && !contains_group(&items[i]) {
            let text = &rendered[i];
            let line_is_empty = line.len() == child_indent.len();
            let separator_len = if line_is_empty { 0 } else { 2 };
            // `+ 1` reserves room for the trailing comma.
            if !line_is_empty && line.len() + separator_len + text.len() + 1 > MAX_LINE_WIDTH {
                out.push_str(&line);
                out.push_str(",\n");
                line.truncate(child_indent.len());
                line.push_str(text);
            } else {
                if !line_is_empty {
                    line.push_str(", ");
                }
                line.push_str(text);
            }
            i += 1;
        }
        if line.len() > child_indent.len() {
            out.push_str(&line);
            out.push_str(",\n");
        }
    }
    out.push_str(&close_indent);
    out.push('}');
}
/// Reports whether a `{...}` list appears anywhere along `node`'s path.
fn contains_group(node: &UseNode) -> bool {
    let mut current: &UseNode = node;
    loop {
        match current {
            UseNode::Group { .. } => return true,
            UseNode::Path { child, .. } => current = child.as_ref(),
            UseNode::Name { .. } | UseNode::Slf { .. } | UseNode::Glob => return false,
        }
    }
}
// Unit tests for `ImportFormattingRule::format_imports`. Every case enables
// grouping, sorting and merging.
#[cfg(test)]
mod tests {
    use super::*;

    /// Nested, unsorted, badly wrapped imports are sorted and re-wrapped:
    /// lowercase names before CamelCase, lists containing nested lists laid out
    /// vertically, and plain items packed up to the width limit.
    #[test]
    fn test_bad_format_to_good_format() {
        let input = r#"use criterion::{Criterion, criterion_group, criterion_main};
use documentdb_gateway_core::{
    configuration::{CertInputType, CertificateOptions, DocumentDBSetupConfiguration},
    postgres::{ conn_mgmt::{ run_request_with_retries, Connection, ConnectionPool, ConnectionSource, PgPoolSettings, QueryOptions, RequestOptions, }, ScopedTransaction, },
    requests::request_tracker::RequestTracker,
};"#;
        let expected = r#"use criterion::{criterion_group, criterion_main, Criterion};
use documentdb_gateway_core::{
    configuration::{CertInputType, CertificateOptions, DocumentDBSetupConfiguration},
    postgres::{
        conn_mgmt::{
            run_request_with_retries, Connection, ConnectionPool, ConnectionSource, PgPoolSettings,
            QueryOptions, RequestOptions,
        },
        ScopedTransaction,
    },
    requests::request_tracker::RequestTracker,
};"#;
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        // Only outer whitespace is ignored; inner indentation must match exactly.
        assert_eq!(result.trim(), expected.trim(), "\n\nGot:\n{result}");
    }

    /// An already-canonical single import is left byte-identical.
    #[test]
    fn test_simple_single_line() {
        let input = "use std::collections::HashMap;\n";
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        assert_eq!(result, input);
    }

    /// Entries inside a `{...}` list are sorted by their rendered path text.
    #[test]
    fn test_sorting_within_braces() {
        let input = "use std::{fmt, collections::HashMap, io};\n";
        let expected = "use std::{collections::HashMap, fmt, io};\n";
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        assert_eq!(result, expected);
    }

    /// Imports are ordered std, external, internal, with a blank line between sections.
    #[test]
    fn test_grouping() {
        let input = "use crate::foo;\nuse std::io;\nuse serde::Serialize;\n";
        let expected = "use std::io;\n\nuse serde::Serialize;\n\nuse crate::foo;\n";
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        assert_eq!(result, expected);
    }

    /// Two imports with the same root are merged into one braced statement.
    #[test]
    fn test_merge_same_root() {
        let input = "use std::io;\nuse std::fmt;\n";
        let expected = "use std::{fmt, io};\n";
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        assert_eq!(result, expected);
    }

    /// `pub` visibility is preserved on re-emission.
    #[test]
    fn test_pub_visibility() {
        let input = "pub use serde::{Deserialize, Serialize};\n";
        let rule = ImportFormattingRule::new(true, true, true);
        let result = rule.format_imports(input);
        assert_eq!(result, input);
    }
}