use std::collections::{HashMap, HashSet};
use syn::{
visit::Visit, Expr, GenericArgument, ImplItemFn, Item, PathArguments, Stmt, Type, TypePath,
};
/// Tracks type names found in a parsed file together with the import paths
/// that bring them into scope.
///
/// Populated by `analyze_file` / `analyze_methods` and queried by
/// `generate_use_statements`.
pub struct ImportAnalyzer {
/// Type name -> set of origin labels. `analyze_methods` currently records
/// only the placeholder label `"unknown"`.
#[allow(dead_code)]
used_types: HashMap<String, HashSet<String>>,
/// Type name -> import path, filled from `use` items and from the
/// structs, enums and type aliases declared in the analyzed file.
type_mappings: HashMap<String, String>,
/// Names of standard-library types seeded by `new` (e.g. `String`, `Vec`,
/// `HashMap`); these are never looked up in `type_mappings`.
#[allow(dead_code)]
std_types: HashSet<String>,
/// Alias name -> string recorded for that alias by `analyze_file`.
type_aliases: HashMap<String, String>,
}
impl Default for ImportAnalyzer {
fn default() -> Self {
Self::new()
}
}
impl ImportAnalyzer {
pub fn new() -> Self {
let mut std_types = HashSet::new();
std_types.insert("String".to_string());
std_types.insert("Vec".to_string());
std_types.insert("Option".to_string());
std_types.insert("Result".to_string());
std_types.insert("Box".to_string());
std_types.insert("Arc".to_string());
std_types.insert("Rc".to_string());
std_types.insert("HashMap".to_string());
std_types.insert("HashSet".to_string());
std_types.insert("BTreeMap".to_string());
std_types.insert("BTreeSet".to_string());
std_types.insert("VecDeque".to_string());
Self {
used_types: HashMap::new(),
type_mappings: HashMap::new(),
std_types,
type_aliases: HashMap::new(),
}
}
/// Scans the top-level items of `file` and records how each type can be imported.
///
/// * `use` items are forwarded to `extract_use_mapping`.
/// * Structs, enums and type aliases declared in the file are mapped to
///   `super::types::<Name>`.
/// * For type aliases, the aliased (right-hand side) type is additionally
///   stored so that `resolve_type_alias` can return it.
///
/// All other item kinds are ignored.
pub fn analyze_file(&mut self, file: &syn::File) {
    for item in &file.items {
        let declared = match item {
            Item::Use(use_item) => {
                self.extract_use_mapping(use_item);
                continue;
            }
            Item::Struct(s) => &s.ident,
            Item::Enum(e) => &e.ident,
            Item::Type(t) => {
                // Render only the aliased type. Rendering the whole item would
                // store the full `type X = Y;` declaration as the "resolved" type.
                let aliased = &t.ty;
                self.type_aliases
                    .insert(t.ident.to_string(), quote::quote!(#aliased).to_string());
                &t.ident
            }
            _ => continue,
        };
        self.type_mappings
            .insert(declared.to_string(), format!("super::types::{}", declared));
    }
}
/// Returns the string stored for `alias_name`, or `alias_name` itself when
/// no alias with that name has been recorded.
#[allow(dead_code)]
pub fn resolve_type_alias(&self, alias_name: &str) -> String {
    match self.type_aliases.get(alias_name) {
        Some(target) => target.clone(),
        None => alias_name.to_string(),
    }
}
/// Reports whether `name` was recorded as a type alias.
#[allow(dead_code)]
pub fn is_type_alias(&self, name: &str) -> bool {
    let known_aliases = &self.type_aliases;
    known_aliases.contains_key(name)
}
/// Records an import path for every capitalised name that `use_item` brings
/// into scope.
///
/// The `use` tree is walked structurally rather than by editing its token
/// rendering, so:
///
/// * paths are stored without embedded whitespace (`std::collections::HashMap`),
///   which lets prefix checks such as `starts_with("super::")` match;
/// * identifiers that merely contain `use` (e.g. `warehouse`) and visibility
///   qualifiers (`pub use ...`) cannot corrupt the stored path;
/// * grouped imports (`use a::{B, c::D};`) yield one entry per leaf;
/// * renames (`use a::B as C;`) are stored under `C` with the path `a::B as C`.
///
/// Glob imports and names that do not start with an uppercase letter
/// (modules, functions, `self`, `_`) are skipped.
fn extract_use_mapping(&mut self, use_item: &syn::ItemUse) {
    // Joins the accumulated path segments and the final `leaf` with `::`.
    fn qualified(prefix: &[String], leaf: &str) -> String {
        prefix
            .iter()
            .map(String::as_str)
            .chain(std::iter::once(leaf))
            .collect::<Vec<_>>()
            .join("::")
    }

    // Only names starting with an uppercase character are treated as types.
    fn is_type_like(name: &str) -> bool {
        name.chars().next().is_some_and(|c| c.is_uppercase())
    }

    // Depth-first walk over the use tree, carrying the path prefix seen so far.
    fn walk(
        tree: &syn::UseTree,
        prefix: &mut Vec<String>,
        mappings: &mut HashMap<String, String>,
    ) {
        match tree {
            syn::UseTree::Path(path) => {
                prefix.push(path.ident.to_string());
                walk(&path.tree, prefix, mappings);
                prefix.pop();
            }
            syn::UseTree::Name(leaf) => {
                let name = leaf.ident.to_string();
                if is_type_like(&name) {
                    let full_path = qualified(prefix, &name);
                    mappings.insert(name, full_path);
                }
            }
            syn::UseTree::Rename(renamed) => {
                let alias = renamed.rename.to_string();
                if is_type_like(&alias) {
                    let original = qualified(prefix, &renamed.ident.to_string());
                    let full_path = format!("{} as {}", original, alias);
                    mappings.insert(alias, full_path);
                }
            }
            syn::UseTree::Group(group) => {
                for subtree in &group.items {
                    walk(subtree, prefix, mappings);
                }
            }
            // Glob imports name no specific type, so there is nothing to record.
            _ => {}
        }
    }

    let mut prefix = Vec::new();
    if use_item.leading_colon.is_some() {
        // An empty first segment renders as a leading `::` once joined.
        prefix.push(String::new());
    }
    walk(&use_item.tree, &mut prefix, &mut self.type_mappings);
}
/// Walks each method with a fresh `TypeVisitor` and registers every type
/// name it reports in `used_types` under the placeholder label `"unknown"`.
#[allow(dead_code)]
pub fn analyze_methods(&mut self, methods: &[&ImplItemFn]) {
    for &method in methods {
        let mut collector = TypeVisitor::new();
        collector.visit_impl_item_fn(method);
        collector.types_used.into_iter().for_each(|name| {
            self.used_types
                .entry(name)
                .or_default()
                .insert("unknown".to_string());
        });
    }
}
/// Builds the sorted list of `use ...;` statements required for `types_needed`.
///
/// * Primitive types are ignored.
/// * Names in `std_types` are never looked up in `type_mappings`. Of those,
///   only the `std::collections` types (`HashMap`, `HashSet`, `VecDeque`,
///   `BTreeMap`, `BTreeSet`) produce output: they are combined into a single
///   `use std::collections::{...};` line. Every other std type name,
///   including `Arc` and `Rc`, is skipped without emitting an import.
/// * Any other name is emitted as `use <path>;` when a path was recorded in
///   `type_mappings`; names without a recorded path are silently omitted.
///
/// The returned statements are de-duplicated and sorted, and the names inside
/// the `std::collections` group are sorted as well, so the output does not
/// depend on hash-set iteration order.
#[allow(dead_code)]
pub fn generate_use_statements(&self, types_needed: &[String]) -> Vec<String> {
    let mut std_collections: HashSet<&str> = HashSet::new();
    let mut import_paths: HashSet<&str> = HashSet::new();

    for type_name in types_needed {
        let name = type_name.as_str();
        if self.is_primitive(name) {
            continue;
        }
        if self.std_types.contains(name) {
            if matches!(
                name,
                "HashMap" | "HashSet" | "VecDeque" | "BTreeMap" | "BTreeSet"
            ) {
                std_collections.insert(name);
            }
            continue;
        }
        if let Some(path) = self.type_mappings.get(name) {
            // `super::`, `crate::` and external paths are all rendered the same
            // way and sorted together below, so one bucket is sufficient.
            import_paths.insert(path.as_str());
        }
    }

    let mut result: Vec<String> = import_paths
        .into_iter()
        .map(|path| format!("use {};", path))
        .collect();

    if !std_collections.is_empty() {
        let mut collections: Vec<&str> = std_collections.into_iter().collect();
        // Sort the group members: hash-set order would otherwise leak into the output.
        collections.sort_unstable();
        result.push(format!(
            "use std::collections::{{{}}};",
            collections.join(", ")
        ));
    }

    result.sort();
    result
}
/// Returns `true` when `type_name` is the name of a built-in primitive type
/// (integers, floats, `bool`, `char`, `str`) or the unit type `()`.
fn is_primitive(&self, type_name: &str) -> bool {
    const PRIMITIVE_NAMES: [&str; 18] = [
        "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
        "f32", "f64", "bool", "char", "str", "()",
    ];
    PRIMITIVE_NAMES.iter().any(|&primitive| primitive == type_name)
}
/// Returns the common import list for code located one module level below
/// the parent, i.e. `infer_imports_with_depth(1)`.
#[allow(dead_code)]
pub fn infer_common_imports(&self) -> Vec<String> {
    let one_level_up = 1;
    self.infer_imports_with_depth(one_level_up)
}
/// Returns a fixed set of three imports: the std hash collections, a glob
/// import of `types`, and `PropertyPathEvaluator`.
///
/// The last two are prefixed with `super::` repeated `depth` times; with
/// `depth == 0` no prefix is added.
#[allow(dead_code)]
pub fn infer_imports_with_depth(&self, depth: usize) -> Vec<String> {
    let parent_path = "super::".repeat(depth);

    let mut imports = Vec::with_capacity(3);
    imports.push(String::from("use std::collections::{HashMap, HashSet};"));
    imports.push(format!("use {}types::*;", parent_path));
    imports.push(format!("use {}PropertyPathEvaluator;", parent_path));
    imports
}
}
/// AST visitor that collects the names of types referenced in the code it walks.
#[allow(dead_code)]
struct TypeVisitor {
/// Identifiers recorded so far: the last path segment of each visited type
/// and of each expression path whose last segment starts with an uppercase
/// letter.
types_used: HashSet<String>,
}
impl TypeVisitor {
fn new() -> Self {
Self {
types_used: HashSet::new(),
}
}
fn extract_type_name(&mut self, ty: &Type) {
match ty {
Type::Path(TypePath { path, .. }) => {
if let Some(segment) = path.segments.last() {
self.types_used.insert(segment.ident.to_string());
if let PathArguments::AngleBracketed(args) = &segment.arguments {
for arg in &args.args {
if let GenericArgument::Type(inner_ty) = arg {
self.extract_type_name(inner_ty);
}
}
}
}
}
Type::Reference(r) => {
self.extract_type_name(&r.elem);
}
Type::Tuple(t) => {
for elem in &t.elems {
self.extract_type_name(elem);
}
}
_ => {}
}
}
}
/// Hooks into syn's AST walk to record type names from type positions and
/// from capitalised expression paths.
///
/// Each override performs its own bookkeeping and then delegates exactly once
/// to the default `syn::visit` traversal. The default traversal already
/// descends into method-call receivers and `let` initialisers, so no explicit
/// extra descent is made: visiting a receiver both explicitly and through the
/// default walk makes the work grow exponentially with the length of a method
/// chain, without adding anything to the recorded set.
impl<'ast> Visit<'ast> for TypeVisitor {
    /// Records the names mentioned by `ty`, then continues into its children.
    fn visit_type(&mut self, ty: &'ast Type) {
        self.extract_type_name(ty);
        syn::visit::visit_type(self, ty);
    }

    /// Records the last segment of a path expression when it starts with an
    /// uppercase letter, then continues into the expression's children.
    fn visit_expr(&mut self, expr: &'ast Expr) {
        if let Expr::Path(expr_path) = expr {
            if let Some(segment) = expr_path.path.segments.last() {
                let name = segment.ident.to_string();
                if name.chars().next().is_some_and(|c| c.is_uppercase()) {
                    self.types_used.insert(name);
                }
            }
        }
        syn::visit::visit_expr(self, expr);
    }

    /// Delegates to the default statement walk, which reaches the initialiser
    /// expression of `let` statements via `visit_expr`.
    fn visit_stmt(&mut self, stmt: &'ast Stmt) {
        syn::visit::visit_stmt(self, stmt);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in std type list contains both prelude and collection types.
    #[test]
    fn test_import_analyzer_std_types() {
        let analyzer = ImportAnalyzer::new();
        assert!(analyzer.std_types.contains("String"));
        assert!(analyzer.std_types.contains("HashMap"));
    }

    /// Primitives are recognised; std library types are not primitives.
    #[test]
    fn test_primitive_detection() {
        let analyzer = ImportAnalyzer::new();
        assert!(analyzer.is_primitive("i32"));
        assert!(analyzer.is_primitive("bool"));
        assert!(analyzer.is_primitive("()"));
        assert!(!analyzer.is_primitive("String"));
    }

    /// Primitives and non-collection std types never produce an import.
    #[test]
    fn test_generate_use_statements() {
        let analyzer = ImportAnalyzer::new();
        let types = vec!["i32".to_string(), "String".to_string()];
        let statements = analyzer.generate_use_statements(&types);
        assert!(statements.is_empty());
    }

    /// A std collection type produces a grouped `std::collections` import.
    #[test]
    fn test_generate_use_statements_single_collection() {
        let analyzer = ImportAnalyzer::new();
        let types = vec!["HashMap".to_string()];
        let statements = analyzer.generate_use_statements(&types);
        assert_eq!(
            statements,
            vec!["use std::collections::{HashMap};".to_string()]
        );
    }

    /// Names with no recorded import path are omitted from the output.
    #[test]
    fn test_generate_use_statements_unknown_type() {
        let analyzer = ImportAnalyzer::new();
        let types = vec!["Unmapped".to_string()];
        assert!(analyzer.generate_use_statements(&types).is_empty());
    }

    /// Unknown aliases resolve to their own name and are not reported as aliases.
    #[test]
    fn test_unknown_alias_resolution() {
        let analyzer = ImportAnalyzer::new();
        assert!(!analyzer.is_type_alias("Missing"));
        assert_eq!(analyzer.resolve_type_alias("Missing"), "Missing");
    }
}