use crate::{ThagResult, BUILT_IN_CRATES};
use phf::phf_set;
use proc_macro2::TokenStream;
use quote::ToTokens;
use regex::Regex;
use std::collections::HashSet;
use std::ops::Deref;
use std::{
collections::HashMap,
hash::BuildHasher,
option::Option,
process::{self},
};
use strum::Display;
use syn::{
self, parse_file,
visit::Visit,
BinOp::{
AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, RemAssign,
ShlAssign, ShrAssign, SubAssign,
},
Expr, File, Item, ItemMod, ItemUse, ReturnType, Stmt,
Type::Tuple,
TypePath, UseRename, UseTree,
};
use thag_common::{debug_log, re, V};
use thag_profiler::profiled;
use thag_styling::{svprtln, Role};
#[cfg(debug_assertions)]
use {crate::debug_timings, std::time::Instant};
/// Lower-case identifiers that `should_filter_dependency` treats as never
/// naming an external crate when they appear as the first segment of a path:
/// primitive type names, a couple of deliberately excluded names (`error`,
/// `fs`) and the path keywords `self`, `super` and `crate`.
///
/// Capitalised names are rejected separately, so only lower-case words need
/// to be listed here. `char` is included because single-segment type paths
/// (e.g. `impl Trait for char`, `type C = char;`) are inspected without a
/// segment-count check and would otherwise be reported as a crate.
pub(crate) static FILTER_WORDS: phf::Set<&'static str> = phf_set! {
    "f32", "f64",
    "i8", "i16", "i32", "i64", "i128", "isize",
    "u8", "u16", "u32", "u64", "u128", "usize",
    "bool", "char", "str",
    "error", "fs",
    "self", "super", "crate"
};
/// A parsed piece of Rust source: either a whole file or a single expression.
///
/// Dependency inference and return-type analysis accept either form.
#[derive(Clone, Debug, Display)]
pub enum Ast {
    /// A complete source file (a sequence of items).
    File(syn::File),
    /// A single expression, e.g. a snippet body.
    Expr(syn::Expr),
}
impl Ast {
    /// Reports whether this AST wraps a whole file (`Ast::File`) rather than a
    /// single expression (`Ast::Expr`).
    #[must_use]
    #[profiled]
    pub const fn is_file(&self) -> bool {
        matches!(self, Self::File(_))
    }
}
impl ToTokens for Ast {
    /// Appends the tokens of the wrapped syntax node, whichever variant it is.
    #[profiled]
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Self::Expr(inner) => inner.to_tokens(tokens),
            Self::File(inner) => inner.to_tokens(tokens),
        }
    }
}
/// Syntax visitor that collects candidate crate names from a Rust AST.
///
/// Candidates come from the leading segment of `use` trees and of qualified
/// paths (attributes, expressions, types, macros, impls, trait bounds).
#[derive(Clone, Debug, Default)]
pub struct CratesFinder {
    /// Candidate crate names, in the order they were first seen.
    pub crates: Vec<String>,
    /// Identifiers found below the first segment of a `use` path. They are
    /// removed from the candidates when dependencies are finally inferred
    /// (presumably because they name items or submodules, not crates).
    pub names_to_exclude: Vec<String>,
}
impl<'a> Visit<'a> for CratesFinder {
#[profiled]
fn visit_attribute(&mut self, attr: &'a syn::Attribute) {
match &attr.meta {
syn::Meta::Path(path) => {
if path.segments.len() > 1 {
if let Some(first_seg) = path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
debug_log!("visit_attribute (path) pushing {name} to crates");
self.crates.push(name);
}
}
}
}
syn::Meta::List(meta_list) => {
if meta_list.path.segments.len() > 1 {
if let Some(first_seg) = meta_list.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
debug_log!("visit_attribute (list) pushing {name} to crates");
self.crates.push(name);
}
}
}
}
syn::Meta::NameValue(meta_name_value) => {
if meta_name_value.path.segments.len() > 1 {
if let Some(first_seg) = meta_name_value.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
debug_log!("visit_attribute (name-value) pushing {name} to crates");
self.crates.push(name);
}
}
}
}
}
syn::visit::visit_attribute(self, attr);
}
#[profiled]
fn visit_item_use(&mut self, node: &'a ItemUse) {
if let UseTree::Rename(use_rename) = &node.tree {
let node_name = use_rename.ident.to_string();
self.crates.push(node_name);
} else {
syn::visit::visit_item_use(self, node);
}
}
#[profiled]
fn visit_use_tree(&mut self, node: &'a UseTree) {
match node {
UseTree::Group(_) => {
syn::visit::visit_use_tree(self, node);
}
UseTree::Path(p) => {
let node_name = p.ident.to_string();
if !should_filter_dependency(&node_name) && !self.crates.contains(&node_name) {
self.crates.push(node_name.clone());
}
let use_tree = &*p.tree;
match use_tree {
UseTree::Path(child) => {
let child_name = child.ident.to_string();
if child_name != node_name && !self.names_to_exclude.contains(&child_name)
{
self.names_to_exclude.push(child_name);
}
}
UseTree::Name(child) => {
let child_name = child.ident.to_string();
if child_name != node_name && !self.names_to_exclude.contains(&child_name)
{
self.names_to_exclude.push(child_name);
}
}
UseTree::Group(group) => {
for child in &group.items {
match child {
UseTree::Path(child) => {
let child_name = child.ident.to_string();
if child_name != node_name && !self.names_to_exclude.contains(&child_name)
{
self.names_to_exclude.push(child_name);
}
}
UseTree::Name(child) => {
let child_name = child.ident.to_string();
if child_name != node_name && !self.names_to_exclude.contains(&child_name)
{
self.names_to_exclude.push(child_name);
}
}
_ => (),
}
}
}
_ => (),
}
syn::visit::visit_use_tree(self, node);
}
UseTree::Name(n) => {
let node_name = n.ident.to_string();
if !self.crates.contains(&node_name) {
self.crates.push(node_name);
}
}
_ => (),
}
}
#[profiled]
fn visit_expr_path(&mut self, expr_path: &'a syn::ExprPath) {
if expr_path.path.segments.len() > 1 {
if let Some(first_seg) = expr_path.path.segments.first() {
let name = first_seg.ident.to_string();
#[cfg(debug_assertions)]
debug_log!("Found first seg {name} in expr_path={expr_path:#?}");
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_expr_path(self, expr_path);
}
#[profiled]
fn visit_type_path(&mut self, type_path: &'a TypePath) {
if type_path.path.segments.len() > 1 {
if let Some(first_seg) = type_path.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_type_path(self, type_path);
}
#[profiled]
fn visit_macro(&mut self, mac: &'a syn::Macro) {
if mac.path.segments.len() > 1 {
if let Some(first_seg) = mac.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_macro(self, mac);
}
#[profiled]
fn visit_item_impl(&mut self, item: &'a syn::ItemImpl) {
if let Some((_, path, _)) = &item.trait_ {
if let Some(first_seg) = path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
if let syn::Type::Path(type_path) = &*item.self_ty {
if let Some(first_seg) = type_path.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_item_impl(self, item);
}
#[profiled]
fn visit_item_type(&mut self, item: &'a syn::ItemType) {
if let syn::Type::Path(type_path) = &*item.ty {
if let Some(first_seg) = type_path.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_item_type(self, item);
}
#[profiled]
fn visit_type_param_bound(&mut self, bound: &'a syn::TypeParamBound) {
if let syn::TypeParamBound::Trait(trait_bound) = bound {
if let Some(first_seg) = trait_bound.path.segments.first() {
let name = first_seg.ident.to_string();
if !should_filter_dependency(&name) && !self.crates.contains(&name) {
self.crates.push(name);
}
}
}
syn::visit::visit_type_param_bound(self, bound);
}
}
/// Syntax visitor that collects the names used to refine dependency inference:
/// `extern crate` targets, locally declared modules, `use` aliases and `main`
/// functions.
#[derive(Clone, Debug, Default)]
pub struct MetadataFinder {
    /// Crates named in `extern crate` declarations (the original name, not any alias).
    pub extern_crates: Vec<String>,
    /// Names of modules declared with `mod`.
    pub mods_to_exclude: Vec<String>,
    /// Aliases introduced by `use ... as alias`.
    pub names_to_exclude: Vec<String>,
    /// Number of `fn main` items encountered.
    pub main_count: usize,
}
impl<'a> Visit<'a> for MetadataFinder {
    /// `use path as alias`: the alias is a local name, so exclude it.
    #[profiled]
    fn visit_use_rename(&mut self, use_rename: &'a UseRename) {
        let alias = use_rename.rename.to_string();
        self.names_to_exclude.push(alias);
        syn::visit::visit_use_rename(self, use_rename);
    }

    /// `extern crate name [as alias];`: record the crate's own name.
    #[profiled]
    fn visit_item_extern_crate(&mut self, extern_crate: &'a syn::ItemExternCrate) {
        let name = extern_crate.ident.to_string();
        self.extern_crates.push(name);
        syn::visit::visit_item_extern_crate(self, extern_crate);
    }

    /// `mod name ...`: a local module, so exclude it.
    #[profiled]
    fn visit_item_mod(&mut self, module: &'a ItemMod) {
        let name = module.ident.to_string();
        self.mods_to_exclude.push(name);
        syn::visit::visit_item_mod(self, module);
    }

    /// Counts every function item called `main`.
    #[profiled]
    fn visit_item_fn(&mut self, function: &'a syn::ItemFn) {
        if function.sig.ident == "main" {
            self.main_count += 1;
        }
        syn::visit::visit_item_fn(self, function);
    }
}
/// Combines the results of a [`CratesFinder`] and a [`MetadataFinder`] into a
/// sorted, de-duplicated list of inferred dependency names.
///
/// A name is dropped if it appears in either finder's exclusion lists, among
/// the declared modules, or in `BUILT_IN_CRATES`. `extern crate` targets are
/// included subject to the same exclusions.
#[must_use]
#[allow(clippy::module_name_repetitions)]
#[profiled]
pub fn infer_deps_from_ast(
    crates_finder: &CratesFinder,
    metadata_finder: &MetadataFinder,
) -> Vec<String> {
    let excluded: HashSet<String> = crates_finder
        .names_to_exclude
        .iter()
        .chain(&metadata_finder.names_to_exclude)
        .chain(&metadata_finder.mods_to_exclude)
        .cloned()
        .chain(BUILT_IN_CRATES.iter().map(Deref::deref).map(String::from))
        .collect();

    let mut inferred: Vec<String> = crates_finder
        .crates
        .iter()
        .chain(&metadata_finder.extern_crates)
        .filter(|name| !excluded.contains(*name))
        .cloned()
        .collect();
    inferred.sort();
    inferred.dedup();
    inferred
}
#[must_use]
#[profiled]
pub fn infer_deps_from_source(code: &str) -> Vec<String> {
if code.trim().is_empty() {
return vec![];
}
let maybe_ast = extract_and_wrap_uses(code);
let mut dependencies = maybe_ast.map_or_else(
|_| {
svprtln!(
Role::ERR,
V::QQ,
"Could not parse code into an abstract syntax tree"
);
vec![]
},
|ast| {
let crates_finder = find_crates(&ast);
let metadata_finder = find_metadata(&ast);
infer_deps_from_ast(&crates_finder, &metadata_finder)
},
);
let macro_use_regex: &Regex = re!(r"(?m)^[\s]*#\[macro_use\((\w+)\)");
let extern_crate_regex: &Regex = re!(r"(?m)^[\s]*extern\s+crate\s+([^;{]+)");
let modules = find_modules_source(code);
dependencies.retain(|e| !modules.contains(e));
for cap in macro_use_regex.captures_iter(code) {
let crate_name = cap[1].to_string();
if !modules.contains(&crate_name) {
dependencies.push(crate_name);
}
}
for cap in extern_crate_regex.captures_iter(code) {
let crate_name = cap[1].to_string();
if !modules.contains(&crate_name) {
dependencies.push(crate_name);
}
}
dependencies.sort();
dependencies
}
/// Runs a [`CratesFinder`] over `syntax_tree` and returns it, holding the
/// candidate crate names and the `use`-path names to exclude.
#[must_use]
#[profiled]
pub fn find_crates(syntax_tree: &Ast) -> CratesFinder {
    let mut finder = CratesFinder::default();
    match syntax_tree {
        Ast::Expr(expr) => finder.visit_expr(expr),
        Ast::File(file) => finder.visit_file(file),
    }
    finder
}
/// Runs a [`MetadataFinder`] over `syntax_tree` and returns it, holding
/// `extern crate` names, module names, `use` aliases and the `main` count.
#[must_use]
#[profiled]
pub fn find_metadata(syntax_tree: &Ast) -> MetadataFinder {
    let mut finder = MetadataFinder::default();
    match syntax_tree {
        Ast::Expr(expr) => finder.visit_expr(expr),
        Ast::File(file) => finder.visit_file(file),
    }
    finder
}
/// Reports whether `name` should never be treated as a crate name.
///
/// Two kinds of name are rejected:
/// - names starting with an uppercase character (by Rust convention these are
///   types, traits or variants rather than crates);
/// - names listed in `FILTER_WORDS`.
///
/// The empty string is not filtered.
#[must_use]
#[profiled]
pub fn should_filter_dependency(name: &str) -> bool {
    name.starts_with(char::is_uppercase) || FILTER_WORDS.contains(name)
}
/// Scans raw source for module declarations and returns the declared names in
/// source order.
///
/// A declaration is recognised at the start of a line (after optional
/// whitespace), with or without a visibility qualifier: `mod foo;`,
/// `mod foo {`, `pub mod foo;`, `pub(crate) mod foo {`, `pub(in path) mod foo;`.
#[must_use]
#[profiled]
pub fn find_modules_source(code: &str) -> Vec<String> {
    // The optional `pub` / `pub(...)` prefix ensures that public modules are
    // also excluded from dependency inference.
    let module_regex: &Regex =
        re!(r"(?m)^[\s]*(?:pub(?:\s*\([^)]*\))?\s+)?mod\s+([^;{\s]+)");
    debug_log!("In ast::find_modules_source");
    let mut modules: Vec<String> = vec![];
    for cap in module_regex.captures_iter(code) {
        let module = cap[1].to_string();
        debug_log!("module={module}");
        modules.push(module);
    }
    debug_log!("modules from source={modules:#?}");
    modules
}
/// Extracts every line-initial `use` declaration from `source` and parses them
/// together as a file.
///
/// Handles simple (`use a::b;`), braced (`use a::{b, c};`, `use {a, b};`) and
/// multi-line declarations, optionally preceded by `pub` or `pub(...)`.
///
/// # Errors
///
/// Returns the `syn::Error` if the extracted declarations do not parse.
#[profiled]
pub fn extract_and_wrap_uses(source: &str) -> Result<Ast, syn::Error> {
    // A `use` tree can never contain `;`, so everything up to the first `;`
    // is exactly one declaration. Using `[^;]+` instead of a greedy dotall
    // `.*` prevents a match from running past the declaration into unrelated
    // code. It also accepts `use path::{...};`, which the earlier
    // brace-excluding pattern rejected.
    let use_regex: &Regex = re!(r"(?m)^\s*(?:pub(?:\s*\([^)]*\))?\s+)?use\s+[^;]+;");
    let use_statements: Vec<&str> = use_regex
        .find_iter(source)
        .map(|found| found.as_str())
        .collect();
    let ast: File = parse_file(&use_statements.join("\n"))?;
    Ok(Ast::File(ast))
}
/// Maps the name of every `fn` item reachable from `expr` to its declared
/// return type.
///
/// The collector does not descend into function bodies, so functions nested
/// inside another function are not included. If a name occurs twice, the
/// later definition wins.
#[profiled]
fn extract_functions(expr: &syn::Expr) -> HashMap<String, ReturnType> {
    #[derive(Default)]
    struct SignatureCollector {
        signatures: HashMap<String, ReturnType>,
    }

    impl<'ast> Visit<'ast> for SignatureCollector {
        #[profiled]
        fn visit_item_fn(&mut self, item_fn: &'ast syn::ItemFn) {
            let fn_name = item_fn.sig.ident.to_string();
            let output = item_fn.sig.output.clone();
            self.signatures.insert(fn_name, output);
        }
    }

    let mut collector = SignatureCollector::default();
    collector.visit_expr(expr);
    collector.signatures
}
/// Heuristically reports whether the snippet expression `expr` evaluates to `()`.
///
/// Functions declared inside `expr` are collected first, so that calls to
/// them can be resolved. In debug builds the elapsed time is reported through
/// `debug_timings`.
#[must_use]
#[inline]
#[profiled]
pub fn is_unit_return_type(expr: &Expr) -> bool {
    #[cfg(debug_assertions)]
    let started_at = Instant::now();
    let unit = is_last_stmt_unit_type(expr, &extract_functions(expr));
    #[cfg(debug_assertions)]
    debug_timings(&started_at, "Determined probable snippet return type");
    unit
}
/// Heuristically decides whether `expr` (typically the final expression of a
/// snippet) evaluates to the unit type `()`.
///
/// `function_map` maps the names of functions declared in the snippet to their
/// declared return types, so that calls to those functions can be resolved.
///
/// The analysis is purely syntactic and several cases are approximations
/// (e.g. every `loop` is treated as unit, and a method call is judged by its
/// receiver). Expression kinds it does not handle produce a warning and `false`.
///
/// The process exits if an `if` expression's `else` branch is neither a block
/// nor another `if`; syn documents that shape as impossible.
#[allow(clippy::too_many_lines, clippy::unnecessary_map_or)]
#[must_use]
#[inline]
#[profiled]
pub fn is_last_stmt_unit_type<S: BuildHasher>(
    expr: &Expr,
    function_map: &HashMap<String, ReturnType, S>,
) -> bool {
    match expr {
        Expr::If(expr_if) => {
            if let Some(last_stmt_in_then_branch) = expr_if.then_branch.stmts.last() {
                if !is_stmt_unit_type(last_stmt_in_then_branch, function_map) {
                    return false;
                }
                expr_if.else_branch.as_ref().map_or(true, |stmt| {
                    let expr_else = &*stmt.1;
                    match expr_else {
                        // An empty `else {}` block is unit, like any empty block.
                        Expr::Block(expr_block) => expr_block.block.stmts.last().map_or(
                            true,
                            |last_stmt_in_block| {
                                is_stmt_unit_type(last_stmt_in_block, function_map)
                            },
                        ),
                        Expr::If(_) => is_last_stmt_unit_type(expr_else, function_map),
                        expr => {
                            eprintln!("Possible logic error: expected else branch expression to be If or Block, found {expr:?}");
                            process::exit(1);
                        }
                    }
                })
            } else {
                // An empty `then` block evaluates to `()`. Any `else` branch must
                // then also be `()` for the code to compile, so the `if` is unit.
                true
            }
        }
        Expr::Block(expr_block) => {
            if expr_block.block.stmts.is_empty() {
                return true;
            }
            expr_block
                .block
                .stmts
                .last()
                .is_some_and(|last_stmt| is_stmt_unit_type(last_stmt, function_map))
        }
        Expr::Match(expr_match) => {
            for arm in &expr_match.arms {
                let expr = &*arm.body;
                if is_last_stmt_unit_type(expr, function_map) {
                    continue;
                }
                return false;
            }
            true
        }
        Expr::Call(expr_call) => {
            if let Expr::Path(path) = &*expr_call.func {
                if let Some(value) = is_path_unit_type(path, function_map) {
                    return value;
                }
            }
            false
        }
        Expr::Closure(ref expr_closure) => match &expr_closure.output {
            ReturnType::Default => is_last_stmt_unit_type(&expr_closure.body, function_map),
            ReturnType::Type(_, ty) => {
                if let Tuple(tuple) = &**ty {
                    tuple.elems.is_empty()
                } else {
                    false
                }
            }
        },
        Expr::MethodCall(expr_method_call) => {
            is_last_stmt_unit_type(&expr_method_call.receiver, function_map)
        }
        // Compound assignments (`x += 1`, ...) evaluate to `()`.
        Expr::Binary(expr_binary) => matches!(
            expr_binary.op,
            AddAssign(_)
                | SubAssign(_)
                | MulAssign(_)
                | DivAssign(_)
                | RemAssign(_)
                | BitXorAssign(_)
                | BitAndAssign(_)
                | BitOrAssign(_)
                | ShlAssign(_)
                | ShrAssign(_)
        ),
        // A `for` loop always evaluates to `()`, whatever its body ends with.
        // (Judging it by the body's last statement wrongly rejected empty
        // bodies and trailing calls to unknown unit-returning functions.)
        Expr::ForLoop(_)
        | Expr::While(_)
        | Expr::Loop(_)
        | Expr::Break(_)
        | Expr::Continue(_)
        | Expr::Infer(_)
        | Expr::Let(_) => true,
        Expr::Array(_)
        | Expr::Assign(_)
        | Expr::Async(_)
        | Expr::Await(_)
        | Expr::Cast(_)
        | Expr::Const(_)
        | Expr::Field(_)
        | Expr::Group(_)
        | Expr::Index(_)
        | Expr::Lit(_)
        | Expr::Paren(_)
        | Expr::Range(_)
        | Expr::Reference(_)
        | Expr::Repeat(_)
        | Expr::Struct(_)
        | Expr::Try(_)
        | Expr::TryBlock(_)
        | Expr::Tuple(_)
        | Expr::Unary(_)
        | Expr::Unsafe(_)
        | Expr::Verbatim(_)
        | Expr::Yield(_) => false,
        // Treat `print*!`, `write*!` and `debug*!` macros as unit.
        Expr::Macro(ref expr_macro) => {
            if let Some(segment) = expr_macro.mac.path.segments.last() {
                let ident = &segment.ident.to_string();
                return ident.starts_with("print")
                    || ident.starts_with("write")
                    || ident.starts_with("debug");
            }
            false
        }
        Expr::Path(ref path) => {
            if let Some(value) = is_path_unit_type(path, function_map) {
                return value;
            }
            false
        }
        Expr::Return(ref expr_return) => expr_return.expr.is_none(),
        _ => {
            svprtln!(
                Role::WARN,
                V::Q,
                "Expression not catered for: {expr:#?}, wrapping expression in println!()"
            );
            false
        }
    }
}
/// Resolves a bare function name against `function_map`.
///
/// Returns:
/// - `Some(true)` if `path` is a single identifier naming a known function
///   whose declared return type is omitted or `()`;
/// - `Some(false)` if it names a known function that returns anything else;
/// - `None` if `path` is qualified or is not in `function_map`.
#[must_use]
#[inline]
#[profiled]
pub fn is_path_unit_type<S: BuildHasher>(
    path: &syn::PatPath,
    function_map: &HashMap<String, ReturnType, S>,
) -> Option<bool> {
    let fn_name = path.path.get_ident()?.to_string();
    let declared = function_map.get(&fn_name)?;
    let returns_unit = match declared {
        ReturnType::Default => true,
        ReturnType::Type(_, ty) => matches!(&**ty, Tuple(tuple) if tuple.elems.is_empty()),
    };
    Some(returns_unit)
}
/// Heuristically decides whether `stmt`, the last statement of a block, leaves
/// that block with the value `()`.
///
/// The rules are:
/// - A trailing expression without `;` is judged by `is_last_stmt_unit_type`.
/// - A `;`-terminated expression is unit, except `return value;` and
///   `yield value;`.
/// - `let` statements are unit.
/// - Item declarations (fn, struct, enum, const, static, `macro_rules!`, ...)
///   are unit.
/// - Other macro statements are unit only when terminated by `;`.
#[profiled]
pub fn is_stmt_unit_type<S: BuildHasher>(
    stmt: &Stmt,
    function_map: &HashMap<String, ReturnType, S>,
) -> bool {
    debug_log!("%%%%%%%% stmt={stmt:#?}");
    match stmt {
        Stmt::Expr(expr, None) => is_last_stmt_unit_type(expr, function_map),
        Stmt::Expr(expr, Some(_)) => match expr {
            Expr::Return(expr_return) => {
                debug_log!("%%%%%%%% expr_return={expr_return:#?}");
                expr_return.expr.is_none()
            }
            Expr::Yield(expr_yield) => {
                debug_log!("%%%%%%%% expr_yield={expr_yield:#?}");
                expr_yield.expr.is_none()
            }
            _ => true,
        },
        Stmt::Macro(m) => m.semi_token.is_some(),
        Stmt::Local(_) => true,
        Stmt::Item(item) => match item {
            // A block whose last statement is an item declaration has no tail
            // expression, so its value is `()`. `const`, `enum` and `static`
            // were previously missing from this list.
            Item::Const(_)
            | Item::Enum(_)
            | Item::ExternCrate(_)
            | Item::Fn(_)
            | Item::ForeignMod(_)
            | Item::Impl(_)
            | Item::Mod(_)
            | Item::Static(_)
            | Item::Struct(_)
            | Item::Trait(_)
            | Item::TraitAlias(_)
            | Item::Type(_)
            | Item::Union(_)
            | Item::Use(_) => true,
            // `macro_rules! name { ... }` (which has an ident) defines an item.
            // Other item-position macro calls count as unit only with a `;`.
            Item::Macro(m) => m.ident.is_some() || m.semi_token.is_some(),
            _ => false,
        },
    }
}
#[profiled]
pub fn is_main_fn_returning_unit(file: &File) -> ThagResult<bool> {
for item in &file.items {
if let Item::Fn(func) = item {
if func.sig.ident == "main" {
let is_unit_return_type = matches!(func.sig.output, ReturnType::Default);
return Ok(is_unit_return_type);
}
}
}
Err("No main function found".into())
}