use syn::{Expr, Item, Local, Pat, Stmt, StmtMacro, UseTree, Visibility};
/// Import category assigned to a `use` item by [`classify_use`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UseGroup {
/// The path is rooted at `std`, `alloc`, or `core`.
Std,
/// Any other root; also the fallback when no root identifier can be found.
External,
/// The path is rooted at `crate`, `super`, or `self`.
CrateLocal,
}
/// Sorts a `use` item into a [`UseGroup`] by looking at its root identifier.
///
/// Roots `std`, `alloc` and `core` map to [`UseGroup::Std`]; `crate`, `super`
/// and `self` map to [`UseGroup::CrateLocal`]. Everything else, including a
/// tree with no discoverable root identifier, is [`UseGroup::External`].
pub fn classify_use(item: &syn::ItemUse) -> UseGroup {
    let Some(root) = use_tree_root_ident(&item.tree) else {
        return UseGroup::External;
    };

    match root.as_str() {
        "std" | "alloc" | "core" => UseGroup::Std,
        "crate" | "super" | "self" => UseGroup::CrateLocal,
        _ => UseGroup::External,
    }
}
/// Returns the textual form of the first identifier of a `use` tree.
///
/// A braced group is resolved through its first entry only. A bare glob, or
/// an empty group, yields `None`.
fn use_tree_root_ident(tree: &UseTree) -> Option<String> {
    let mut current = tree;
    loop {
        let ident = match current {
            UseTree::Path(node) => &node.ident,
            UseTree::Name(node) => &node.ident,
            UseTree::Rename(node) => &node.ident,
            UseTree::Glob(_) => return None,
            UseTree::Group(braced) => {
                // Only the first entry of the group decides the root.
                current = braced.items.first()?;
                continue;
            }
        };
        return Some(ident.to_string());
    }
}
/// Coarse item category produced by `classify_item_kind` and compared pairwise
/// by `should_blank_between_items`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ItemKind {
/// A `use` declaration.
Use,
/// An `extern crate` declaration.
ExternCrate,
/// A `mod` item.
Mod,
/// A `const` item.
Const,
/// A `static` item whose initializer is not considered heavy.
Static,
/// A `type` alias.
TypeAlias,
/// Functions, structs, enums, unions, traits, trait aliases, impls, item
/// macros, and statics with a heavy initializer.
Definition,
/// Any item kind not covered by the variants above.
Other,
}
/// Decides whether a `static` initializer is "heavy".
///
/// Closures, blocks, `async` blocks and `unsafe` blocks are heavy outright.
/// A function call is heavy when any argument is heavy; a method call is heavy
/// when its receiver or any argument is heavy. All other expressions are not.
fn static_init_is_heavy(expr: &Expr) -> bool {
    if matches!(
        expr,
        Expr::Closure(_) | Expr::Block(_) | Expr::Async(_) | Expr::Unsafe(_)
    ) {
        return true;
    }

    match expr {
        Expr::Call(call) => call.args.iter().any(static_init_is_heavy),
        // Receiver is inspected first, then the arguments in order.
        Expr::MethodCall(method) => std::iter::once(&*method.receiver)
            .chain(method.args.iter())
            .any(static_init_is_heavy),
        _ => false,
    }
}
/// Maps a top-level item to its [`ItemKind`].
///
/// A `static` whose initializer is heavy (see `static_init_is_heavy`) is
/// reported as [`ItemKind::Definition`] rather than [`ItemKind::Static`].
fn classify_item_kind(item: &Item) -> ItemKind {
    match item {
        Item::Use(_) => ItemKind::Use,
        Item::ExternCrate(_) => ItemKind::ExternCrate,
        Item::Mod(_) => ItemKind::Mod,
        Item::Const(_) => ItemKind::Const,
        Item::Type(_) => ItemKind::TypeAlias,

        Item::Static(decl) if static_init_is_heavy(&decl.expr) => ItemKind::Definition,
        Item::Static(_) => ItemKind::Static,

        Item::Fn(_)
        | Item::Struct(_)
        | Item::Enum(_)
        | Item::Union(_)
        | Item::Trait(_)
        | Item::TraitAlias(_)
        | Item::Impl(_)
        | Item::Macro(_) => ItemKind::Definition,

        _ => ItemKind::Other,
    }
}
pub fn should_blank_between_items(prev: &Item, next: &Item) -> bool {
let pk = classify_item_kind(prev);
let nk = classify_item_kind(next);
match (pk, nk) {
(ItemKind::Use, ItemKind::Use) => {
let Item::Use(prev_use) = prev else { unreachable!() };
let Item::Use(next_use) = next else { unreachable!() };
let prev_group = classify_use(prev_use);
let next_group = classify_use(next_use);
let prev_cfg = prev_use.attrs.iter().any(|a| a.path().is_ident("cfg"));
let next_cfg = next_use.attrs.iter().any(|a| a.path().is_ident("cfg"));
prev_group != next_group || prev_cfg != next_cfg
}
(ItemKind::ExternCrate, ItemKind::ExternCrate) => false,
(ItemKind::Mod, ItemKind::Mod) => {
let prev_pub = matches!(prev, Item::Mod(m) if matches!(m.vis, Visibility::Public(_)));
let next_pub = matches!(next, Item::Mod(m) if matches!(m.vis, Visibility::Public(_)));
prev_pub != next_pub
}
(ItemKind::Const, ItemKind::Const) => false,
(ItemKind::Static, ItemKind::Static) => false,
(ItemKind::TypeAlias, ItemKind::TypeAlias) => false,
_ => true,
}
}
/// Visual "weight" of a statement, used by `should_blank_between_stmts` to
/// decide on blank lines between neighbours.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StmtWeight {
/// A `let` without a heavy initializer whose pattern has no `mut`.
Light,
/// A `let` without a heavy initializer whose pattern contains `mut`.
Binding,
/// A non-heavy expression statement, or a statement-position macro call.
Medium,
/// A statement built around a heavy expression, or a `let ... else`.
Heavy,
/// An item declared in statement position.
Item,
}
/// Decides whether an expression is "heavy".
///
/// Heavy expressions are `if`, `match`, `for`, `while`, `loop`, plain blocks,
/// `unsafe` blocks, `try` blocks, and closures whose body is a block. An
/// assignment is as heavy as its right-hand side.
fn expr_is_heavy(expr: &Expr) -> bool {
    // Peel off (possibly chained) assignments and judge the final right side.
    let mut node = expr;
    while let Expr::Assign(assign) = node {
        node = &*assign.right;
    }

    match node {
        Expr::Closure(closure) => matches!(&*closure.body, Expr::Block(_)),
        other => matches!(
            other,
            Expr::If(_)
                | Expr::Match(_)
                | Expr::ForLoop(_)
                | Expr::While(_)
                | Expr::Loop(_)
                | Expr::Block(_)
                | Expr::Unsafe(_)
                | Expr::TryBlock(_)
        ),
    }
}
/// Reports whether a pattern contains a `mut` marker anywhere inside it.
///
/// Both `mut` bindings (`mut x`) and mutable reference patterns (`&mut x`)
/// count. The search descends through struct, tuple, tuple-struct, slice,
/// or-patterns, type-ascribed patterns, parenthesized patterns, and the
/// sub-pattern of an `@` binding (`x @ Some(mut y)`). Any other pattern kind
/// is treated as containing no `mut`.
fn pat_contains_mut(pat: &Pat) -> bool {
    match pat {
        Pat::Ident(p) => {
            // `name @ subpattern`: the sub-pattern may introduce its own `mut`
            // bindings even when `name` itself is immutable.
            p.mutability.is_some()
                || p
                    .subpat
                    .as_ref()
                    .map_or(false, |(_, sub)| pat_contains_mut(sub))
        }
        Pat::Reference(p) => p.mutability.is_some() || pat_contains_mut(&p.pat),
        // Parentheses are purely syntactic; look at what they wrap.
        Pat::Paren(p) => pat_contains_mut(&p.pat),
        Pat::Struct(p) => p.fields.iter().any(|f| pat_contains_mut(&f.pat)),
        Pat::Tuple(p) => p.elems.iter().any(pat_contains_mut),
        Pat::TupleStruct(p) => p.elems.iter().any(pat_contains_mut),
        Pat::Slice(p) => p.elems.iter().any(pat_contains_mut),
        Pat::Or(p) => p.cases.iter().any(pat_contains_mut),
        Pat::Type(p) => pat_contains_mut(&p.pat),
        _ => false,
    }
}
/// Assigns a [`StmtWeight`] to a `let` statement.
///
/// A `let ... else`, or a `let` whose initializer is heavy, is `Heavy`.
/// Otherwise the result is `Binding` when the pattern contains `mut` and
/// `Light` when it does not.
fn classify_local(local: &Local) -> StmtWeight {
    match &local.init {
        Some(init) if init.diverge.is_some() || expr_is_heavy(&init.expr) => StmtWeight::Heavy,
        _ if pat_contains_mut(&local.pat) => StmtWeight::Binding,
        _ => StmtWeight::Light,
    }
}
/// Log level recognised from a macro whose final path segment is one of
/// `trace`, `debug`, `info`, `warn`, or `error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TracingLevel {
/// Macro named `trace`.
Trace,
/// Macro named `debug`.
Debug,
/// Macro named `info`.
Info,
/// Macro named `warn`.
Warn,
/// Macro named `error`.
Error,
}
fn tracing_level_from_macro(mac: &syn::Macro) -> Option<TracingLevel> {
let name = mac.path.segments.last()?.ident.to_string();
match name.as_str() {
"trace" => Some(TracingLevel::Trace),
"debug" => Some(TracingLevel::Debug),
"info" => Some(TracingLevel::Info),
"warn" => Some(TracingLevel::Warn),
"error" => Some(TracingLevel::Error),
_ => None,
}
}
/// Returns the log level when `stmt` is a logging macro invocation.
///
/// Handles both shapes a macro call can take in statement position: an
/// expression statement wrapping a macro expression, and a statement macro.
fn stmt_is_tracing(stmt: &Stmt) -> Option<TracingLevel> {
    let invocation = match stmt {
        Stmt::Expr(Expr::Macro(expr_macro), _) => &expr_macro.mac,
        Stmt::Macro(StmtMacro { mac, .. }) => mac,
        _ => return None,
    };
    tracing_level_from_macro(invocation)
}
/// Blank-line rule for pairs where at least one statement is a logging macro.
///
/// Returns `None` when neither statement is a logging call, so the caller can
/// fall through to its generic rules. Otherwise:
///
/// * If `prev` is a logging call it alone decides: no blank line after
///   `trace`, a blank line after every other level.
/// * Else if `next` is a logging call: no blank line before `debug`, a blank
///   line before every other level.
fn tracing_blank_line(prev: &Stmt, next: &Stmt) -> Option<bool> {
    stmt_is_tracing(prev)
        .map(|level| !matches!(level, TracingLevel::Trace))
        .or_else(|| stmt_is_tracing(next).map(|level| !matches!(level, TracingLevel::Debug)))
}
/// Assigns a [`StmtWeight`] to any statement.
///
/// `let` statements defer to `classify_local`; nested items are `Item`;
/// expression statements are `Heavy` when `expr_is_heavy` says so and
/// `Medium` otherwise; statement macros are always `Medium`.
fn classify_weight(stmt: &Stmt) -> StmtWeight {
    match stmt {
        Stmt::Local(binding) => classify_local(binding),
        Stmt::Item(_) => StmtWeight::Item,
        Stmt::Expr(inner, _) if expr_is_heavy(inner) => StmtWeight::Heavy,
        Stmt::Expr(..) | Stmt::Macro(_) => StmtWeight::Medium,
    }
}
/// True when the statement is a bare `return`, `continue`, or `break`
/// expression (with or without a trailing semicolon).
fn is_jump_stmt(stmt: &Stmt) -> bool {
    match stmt {
        Stmt::Expr(inner, _) => matches!(
            inner,
            Expr::Return(_) | Expr::Continue(_) | Expr::Break(_)
        ),
        _ => false,
    }
}
/// True when the statement is a `let ... else { ... }`, i.e. a `let` whose
/// initializer carries a diverging `else` branch.
fn is_let_else(stmt: &Stmt) -> bool {
    let Stmt::Local(binding) = stmt else {
        return false;
    };
    binding
        .init
        .as_ref()
        .map_or(false, |init| init.diverge.is_some())
}
fn should_blank_between_stmts(prev: &Stmt, next: &Stmt) -> bool {
if let Some(decision) = tracing_blank_line(prev, next) {
return decision;
}
if is_jump_stmt(next) {
return true;
}
if is_let_else(next) {
return true;
}
let pw = classify_weight(prev);
let nw = classify_weight(next);
if nw == StmtWeight::Item || pw == StmtWeight::Item {
return true;
}
if pw == StmtWeight::Heavy || nw == StmtWeight::Heavy {
return true;
}
if pw == nw {
return false;
}
true
}
pub fn stmt_blank_lines(stmts: &[Stmt]) -> Vec<bool> {
let len = stmts.len();
let mut blanks = vec![false; len];
if len <= 1 {
return blanks;
}
for i in 1..len {
blanks[i] = should_blank_between_stmts(&stmts[i - 1], &stmts[i]);
}
if len > 1 {
if let syn::Stmt::Expr(_, None) = &stmts[len - 1] {
blanks[len - 1] = true;
}
}
blanks
}