use std::collections::HashSet;
use syn::{
visit::Visit, Block, Expr, ExprAsync, ExprAwait, ExprBlock, ExprCall, ExprClosure,
ExprMethodCall, Item, ItemFn, Stmt, UseTree,
};
/// Syntactic detector for blocking I/O calls made inside async contexts.
///
/// Run it over a parsed file with [`AsyncBoundaryDetector::analyze_file`]; every call it
/// judges blocking while inside an async context is appended to `blocking_in_async`.
#[derive(Debug)]
pub struct AsyncBoundaryDetector {
    /// Stack of "is this scope async?" flags; the innermost scope is the last entry.
    async_stack: Vec<bool>,
    /// Set by the visitor while inside an async fn, an async block, or a block
    /// expression containing `.await`.
    pub in_async_boundary: bool,
    /// Calls judged blocking that were found in async contexts, in visit order.
    pub blocking_in_async: Vec<BlockingCall>,
    /// Summary of the file's `use` declarations, used to disambiguate a bare `Command`.
    imports: ImportTracker,
}
/// Summary of a file's `use` declarations, kept to tell std's `Command` apart from
/// tokio/async-std `Command` types.
#[derive(Debug, Clone, Default)]
struct ImportTracker {
/// Set when a tokio/async-std `Command` type, or its `process` module via a glob, is imported.
has_async_command: bool,
/// Set when `std::process::Command`, or `std::process::*`, is imported.
has_std_command: bool,
/// Full paths of named imports, plus the local names under which `Command` types are imported.
imported_symbols: HashSet<String>,
}
/// A call flagged as potentially blocking I/O.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockingCall {
    /// Callee description: the `::`-joined path of a function call, or
    /// `receiver.method` for a method call.
    pub function_name: String,
    /// Whether the call was judged blocking.
    pub is_blocking: bool,
    /// Whether the call occurred inside an async context.
    pub in_async_context: bool,
    /// Source line of the call; `0` means the line was not recorded.
    pub line: usize,
}
impl AsyncBoundaryDetector {
    /// Creates a detector with no findings, positioned at a non-async top-level scope.
    pub fn new() -> Self {
        Self {
            // Bottom-of-stack sentinel: code outside any fn/closure/async block is not async.
            async_stack: vec![false],
            in_async_boundary: false,
            blocking_in_async: Vec::new(),
            imports: ImportTracker::default(),
        }
    }

    /// Analyzes a whole parsed file.
    ///
    /// Top-level `use` declarations are collected first, so call sites anywhere in the
    /// file are judged against them. Then every item is visited.
    pub fn analyze_file(&mut self, file: &syn::File) {
        self.analyze_imports(&file.items);
        self.visit_file(file);
    }

    /// Records every top-level `use` declaration in `items`.
    ///
    /// `use` items nested inside modules or function bodies are not inspected.
    fn analyze_imports(&mut self, items: &[Item]) {
        for item in items {
            if let Item::Use(use_item) = item {
                self.process_use_tree(&use_item.tree, String::new());
            }
        }
    }

    /// Recursively flattens a `use` tree and records each leaf import.
    ///
    /// `prefix` is the `::`-joined path accumulated from the enclosing segments.
    fn process_use_tree(&mut self, tree: &UseTree, prefix: String) {
        match tree {
            UseTree::Path(path) => {
                let new_prefix = join_import_path(&prefix, &path.ident.to_string());
                self.process_use_tree(&path.tree, new_prefix);
            }
            UseTree::Name(name) => {
                let symbol = name.ident.to_string();
                let full_path = join_import_path(&prefix, &symbol);
                self.imports.record_named_import(full_path, symbol);
            }
            UseTree::Glob(_) => self.imports.record_glob_import(&prefix),
            UseTree::Group(group) => {
                for tree in &group.items {
                    self.process_use_tree(tree, prefix.clone());
                }
            }
            UseTree::Rename(rename) => {
                let full_path = join_import_path(&prefix, &rename.ident.to_string());
                self.imports
                    .record_renamed_import(full_path, rename.rename.to_string());
            }
        }
    }

    /// Whether the innermost scope being visited is an async context.
    fn is_in_async(&self) -> bool {
        self.async_stack.last().copied().unwrap_or(false)
    }

    /// Enters a scope that is async if `is_async` is set or the enclosing scope is async.
    fn push_async(&mut self, is_async: bool) {
        self.async_stack.push(is_async || self.is_in_async());
    }

    /// Leaves the innermost scope.
    fn pop_async(&mut self) {
        self.async_stack.pop();
    }

    /// Decides heuristically whether a call is blocking I/O.
    ///
    /// For function calls the visitor passes the full callee path, for example
    /// `std::fs::read_to_string`, with its last segment as `method`. For method calls it
    /// passes the receiver's path and the method name. The receiver path is something like
    /// `Command`, or `""` when the receiver is not a plain path or constructor call.
    ///
    /// A `(module, func)` pattern matches when `method == func` and `path` either equals
    /// `module`, lies under `module::`, or ends with `::module`.
    fn is_blocking_io(&self, path: &str, method: &str) -> bool {
        // Paths rooted in an async runtime are never reported.
        const ASYNC_ROOTS: &[&str] = &["tokio::", "async_std::", "futures::", "smol::"];
        // Patterns are written in the shape the visitor produces. A call such as
        // `std::net::TcpStream::connect(..)` reaches here as
        // ("std::net::TcpStream::connect", "connect").
        const BLOCKING_PATTERNS: &[(&str, &str)] = &[
            ("std::fs", "read"),
            ("std::fs", "write"),
            ("std::fs", "read_to_string"),
            ("std::fs", "read_dir"),
            ("std::fs", "copy"),
            ("std::fs", "rename"),
            ("std::fs", "remove_file"),
            ("File", "open"),
            ("File", "create"),
            ("std::fs::File", "open"),
            ("std::fs::File", "create"),
            ("std::net::TcpStream", "connect"),
            ("std::net::TcpStream", "connect_timeout"),
            ("std::net::TcpListener", "bind"),
            ("std::net::UdpSocket", "bind"),
            ("std::process::Command", "output"),
            ("std::process::Command", "status"),
            ("std::process::Command", "spawn"),
            ("std::thread", "sleep"),
            ("thread", "sleep"),
            ("reqwest::blocking", "get"),
            ("ureq", "get"),
            ("ureq", "post"),
        ];
        // Generic I/O method names. Trusted only for `std` paths or unknown receivers.
        const BLOCKING_METHODS: &[&str] = &[
            "read_to_string",
            "read_to_end",
            "read_exact",
            "write_all",
            "flush",
            "sync_all",
            "set_len",
            "sleep",
            "wait",
            "join",
        ];

        if ASYNC_ROOTS.iter().any(|root| path.starts_with(root)) {
            return false;
        }

        // A bare `Command` is ambiguous. Report it only when std's `Command` is the only
        // `Command` the file imports.
        if path == "Command" && matches!(method, "output" | "status" | "spawn") {
            return self.imports.has_std_command && !self.imports.has_async_command;
        }

        let in_module = |module: &str| {
            path == module
                || path.starts_with(&format!("{module}::"))
                || path.ends_with(&format!("::{module}"))
        };
        if BLOCKING_PATTERNS
            .iter()
            .any(|&(module, func)| method == func && in_module(module))
        {
            return true;
        }

        BLOCKING_METHODS.contains(&method) && (path.is_empty() || path.starts_with("std::"))
    }

    /// Returns `true` if an expression statement or `let` initializer in `block` contains
    /// an `.await`. The `else` branch of a `let … else` is also checked.
    ///
    /// Item and macro statements are not inspected.
    fn detect_async_boundary(&self, block: &Block) -> bool {
        block.stmts.iter().any(|stmt| match stmt {
            Stmt::Expr(expr, _) => contains_await(expr),
            Stmt::Local(local) => match &local.init {
                Some(init) => {
                    contains_await(&init.expr)
                        || matches!(
                            &init.diverge,
                            Some((_, else_branch)) if contains_await(else_branch)
                        )
                }
                None => false,
            },
            _ => false,
        })
    }
}
impl ImportTracker {
    /// Records `use <full_path>;`, where `symbol` is the name it binds locally.
    ///
    /// The full path is always remembered. `symbol` is remembered too when the path is a
    /// known std or async `Command` type.
    fn record_named_import(&mut self, full_path: String, symbol: String) {
        self.record_command_import(&full_path, symbol);
        self.imported_symbols.insert(full_path);
    }

    /// Records `use <full_path> as <alias>;`.
    ///
    /// A rename is tracked exactly like a named import whose local name is `alias`, so its
    /// full path is remembered as well.
    fn record_renamed_import(&mut self, full_path: String, alias: String) {
        self.record_named_import(full_path, alias);
    }

    /// Records `use <prefix>::*;`. This brings `Command` into scope when `prefix` is a std
    /// or async `process` module.
    fn record_glob_import(&mut self, prefix: &str) {
        if is_async_command_module(prefix) {
            self.has_async_command = true;
        } else if is_std_command_module(prefix) {
            self.has_std_command = true;
        }
    }

    /// If `full_path` names a std or async `Command` type, sets the matching flag and
    /// remembers the local name it was imported under.
    fn record_command_import(&mut self, full_path: &str, imported_symbol: String) {
        if is_async_command_path(full_path) {
            self.has_async_command = true;
        } else if is_std_command_path(full_path) {
            self.has_std_command = true;
        } else {
            return;
        }
        self.imported_symbols.insert(imported_symbol);
    }
}
/// Extends the `::`-separated import path `prefix` with `symbol`.
///
/// An empty `prefix`, as at the root of a `use` tree, yields `symbol` unchanged.
fn join_import_path(prefix: &str, symbol: &str) -> String {
    match prefix {
        "" => symbol.to_owned(),
        _ => [prefix, symbol].join("::"),
    }
}
/// Returns `true` for fully qualified async `Command` types from tokio, async-std or smol.
///
/// smol is included because it is one of the async runtimes whose paths are treated as
/// non-blocking elsewhere in this detector.
fn is_async_command_path(path: &str) -> bool {
    matches!(
        path,
        "tokio::process::Command" | "async_std::process::Command" | "smol::process::Command"
    )
}
/// Returns `true` only for the fully qualified std `Command` type.
fn is_std_command_path(path: &str) -> bool {
    matches!(path, "std::process::Command")
}
/// Returns `true` for async `process` modules from tokio, async-std or smol.
///
/// A glob import of one of these modules brings an async `Command` into scope.
fn is_async_command_module(path: &str) -> bool {
    matches!(path, "tokio::process" | "async_std::process" | "smol::process")
}
/// Returns `true` only for the std `process` module path.
fn is_std_command_module(path: &str) -> bool {
    "std::process" == path
}
impl<'ast> Visit<'ast> for AsyncBoundaryDetector {
fn visit_item_fn(&mut self, node: &'ast ItemFn) {
let is_async = node.sig.asyncness.is_some();
self.push_async(is_async);
if is_async {
self.in_async_boundary = true;
}
syn::visit::visit_item_fn(self, node);
self.pop_async();
if is_async {
self.in_async_boundary = false;
}
}
fn visit_expr_async(&mut self, node: &'ast ExprAsync) {
self.push_async(true);
self.in_async_boundary = true;
syn::visit::visit_expr_async(self, node);
self.pop_async();
}
fn visit_expr_closure(&mut self, node: &'ast ExprClosure) {
let is_async = node.asyncness.is_some();
self.push_async(is_async);
syn::visit::visit_expr_closure(self, node);
self.pop_async();
}
fn visit_expr_call(&mut self, node: &'ast ExprCall) {
if self.is_in_async() {
if let Expr::Path(path) = &*node.func {
let path_str = path
.path
.segments
.iter()
.map(|s| s.ident.to_string())
.collect::<Vec<_>>()
.join("::");
let last_segment = path
.path
.segments
.last()
.map(|s| s.ident.to_string())
.unwrap_or_default();
if self.is_blocking_io(&path_str, &last_segment) {
self.blocking_in_async.push(BlockingCall {
function_name: path_str,
is_blocking: true,
in_async_context: true,
line: 0, });
}
}
}
syn::visit::visit_expr_call(self, node);
}
fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
if self.is_in_async() {
let method_name = node.method.to_string();
let receiver_str = match &*node.receiver {
Expr::Path(path) => path
.path
.segments
.iter()
.map(|s| s.ident.to_string())
.collect::<Vec<_>>()
.join("::"),
Expr::Call(call) => {
if let Expr::Path(path) = &*call.func {
let full_path = path
.path
.segments
.iter()
.map(|s| s.ident.to_string())
.collect::<Vec<_>>()
.join("::");
if full_path.ends_with("::new") {
full_path
.strip_suffix("::new")
.unwrap_or(&full_path)
.to_string()
} else {
full_path
}
} else {
String::new()
}
}
_ => String::new(),
};
if self.is_blocking_io(&receiver_str, &method_name) {
self.blocking_in_async.push(BlockingCall {
function_name: format!("{}.{}", receiver_str, method_name),
is_blocking: true,
in_async_context: true,
line: 0,
});
}
}
syn::visit::visit_expr_method_call(self, node);
}
fn visit_expr_block(&mut self, node: &'ast ExprBlock) {
let has_boundary = self.detect_async_boundary(&node.block);
if has_boundary {
self.in_async_boundary = true;
}
syn::visit::visit_expr_block(self, node);
if has_boundary {
self.in_async_boundary = false;
}
}
}
fn contains_await(expr: &Expr) -> bool {
struct AwaitChecker {
has_await: bool,
}
impl<'ast> Visit<'ast> for AwaitChecker {
fn visit_expr_await(&mut self, _: &'ast ExprAwait) {
self.has_await = true;
}
}
let mut checker = AwaitChecker { has_await: false };
checker.visit_expr(expr);
checker.has_await
}
impl Default for AsyncBoundaryDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh detector and runs only the import analysis over `source`.
    fn imports_of(source: &str) -> AsyncBoundaryDetector {
        let file = syn::parse_file(source).unwrap();
        let mut detector = AsyncBoundaryDetector::new();
        detector.analyze_imports(&file.items);
        detector
    }

    #[test]
    fn test_blocking_io_detection() {
        let detector = AsyncBoundaryDetector::new();
        let blocking = [
            ("std::fs", "read"),
            ("std::fs::read_to_string", "read_to_string"),
            ("File", "open"),
        ];
        for &(path, method) in &blocking {
            assert!(detector.is_blocking_io(path, method));
        }
        let non_blocking = [("tokio::fs", "read"), ("async_std::fs", "read")];
        for &(path, method) in &non_blocking {
            assert!(!detector.is_blocking_io(path, method));
        }
    }

    #[test]
    fn test_async_context_detection() {
        let code = r#"
async fn process_data() {
let data = std::fs::read_to_string("file.txt").unwrap();
process(data).await;
}
"#;
        let file = syn::parse_file(code).unwrap();
        let mut detector = AsyncBoundaryDetector::new();
        let functions = file.items.iter().filter_map(|item| match item {
            syn::Item::Fn(func) => Some(func),
            _ => None,
        });
        for func in functions {
            detector.visit_item_fn(func);
        }
        assert!(!detector.blocking_in_async.is_empty());
    }

    #[test]
    fn grouped_std_command_import_is_tracked() {
        let detector = imports_of("use std::{fs, process::Command};");
        let tracker = &detector.imports;
        assert!(tracker.has_std_command);
        assert!(tracker.imported_symbols.contains("Command"));
        assert!(tracker.imported_symbols.contains("std::process::Command"));
    }

    #[test]
    fn renamed_async_command_import_is_tracked_by_alias() {
        let detector = imports_of("use tokio::process::Command as TokioCommand;");
        let tracker = &detector.imports;
        assert!(tracker.has_async_command);
        assert!(tracker.imported_symbols.contains("TokioCommand"));
    }

    #[test]
    fn glob_command_import_tracks_command_source() {
        let detector = imports_of("use async_std::process::*;");
        let tracker = &detector.imports;
        assert!(tracker.has_async_command);
        assert!(!tracker.has_std_command);
    }
}