use std::collections::{HashMap, HashSet};
use syn::{visit::Visit, Expr, ExprCall, ExprPath, Item, ItemTrait};
#[derive(Debug, Default)]
pub struct TraitMethodTracker {
trait_methods: HashMap<String, HashSet<String>>,
method_to_trait: HashMap<String, String>,
trait_to_module: HashMap<String, String>,
}
impl TraitMethodTracker {
pub fn new() -> Self {
Self::default()
}
pub fn analyze_file(&mut self, file: &syn::File) {
for item in &file.items {
if let Item::Trait(trait_item) = item {
self.analyze_trait(trait_item);
}
}
}
fn analyze_trait(&mut self, trait_item: &ItemTrait) {
let trait_name = trait_item.ident.to_string();
let mut methods = HashSet::new();
for item in &trait_item.items {
if let syn::TraitItem::Fn(method) = item {
let method_name = method.sig.ident.to_string();
methods.insert(method_name.clone());
self.method_to_trait.insert(method_name, trait_name.clone());
}
}
self.trait_methods.insert(trait_name, methods);
}
pub fn register_trait_module(&mut self, trait_name: &str, module_name: &str) {
self.trait_to_module
.insert(trait_name.to_string(), module_name.to_string());
}
pub fn get_trait_for_method(&self, method_name: &str) -> Option<&String> {
self.method_to_trait.get(method_name)
}
pub fn get_trait_module(&self, trait_name: &str) -> Option<&String> {
self.trait_to_module.get(trait_name)
}
#[allow(dead_code)]
pub fn get_all_traits(&self) -> Vec<String> {
self.trait_methods.keys().cloned().collect()
}
#[allow(dead_code)]
pub fn is_trait_method(&self, method_name: &str) -> bool {
self.method_to_trait.contains_key(method_name)
}
pub fn get_required_trait_imports(
&self,
items: &[Item],
current_module: &str,
) -> HashMap<String, String> {
let mut required_imports = HashMap::new();
let mut collector = TraitMethodCallCollector::new();
for item in items {
collector.visit_item(item);
}
for method_name in collector.method_calls {
if let Some(trait_name) = self.get_trait_for_method(&method_name) {
if let Some(trait_module) = self.get_trait_module(trait_name) {
if trait_module != current_module {
required_imports.insert(trait_name.clone(), trait_module.clone());
}
}
}
}
required_imports
}
}
/// Syntax visitor that collects the names of calls which may resolve to trait methods.
///
/// Two call shapes are recorded by name:
/// - path calls whose callee path has at least two segments, such as `f32::simd_add()` or
///   `SimdOps::simd_add()`; the last segment is recorded. Module-qualified free-function
///   calls like `module::func()` are recorded as well.
/// - method-call syntax, such as `value.simd_add()`.
///
/// A trait method reached through either shape only resolves when the trait is in scope, so
/// both must be collected. Only the bare name is kept: receivers and types are not inspected.
/// As a result, an inherent method that shares a name with a trait method is also recorded.
/// This over-approximates; an extra import is a warning, a missing one is a compile error.
struct TraitMethodCallCollector {
    /// Distinct call names seen so far.
    method_calls: HashSet<String>,
}

impl TraitMethodCallCollector {
    /// Creates a collector with no recorded calls.
    fn new() -> Self {
        Self {
            method_calls: HashSet::new(),
        }
    }
}

impl<'ast> Visit<'ast> for TraitMethodCallCollector {
    fn visit_expr(&mut self, expr: &'ast Expr) {
        match expr {
            // `Type::method(..)` / `Trait::method(..)`. A single-segment callee is a plain
            // function or local, not an associated function, so it is skipped.
            Expr::Call(ExprCall { func, .. }) => {
                if let Expr::Path(ExprPath { path, .. }) = &**func {
                    if path.segments.len() >= 2 {
                        if let Some(last) = path.segments.last() {
                            self.method_calls.insert(last.ident.to_string());
                        }
                    }
                }
            }
            // `receiver.method(..)`: previously missed, although a trait method called this
            // way also requires the trait to be imported.
            Expr::MethodCall(call) => {
                self.method_calls.insert(call.method.to_string());
            }
            _ => {}
        }
        // Keep descending so calls nested in arguments, receivers, closures and blocks are seen.
        syn::visit::visit_expr(self, expr);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a Rust source file, panicking with a fixed message on syntax errors.
    fn parse(src: &str) -> syn::File {
        syn::parse_file(src).expect("Failed to parse")
    }

    #[test]
    fn test_trait_method_tracker_creation() {
        let fresh = TraitMethodTracker::new();
        assert!(fresh.trait_methods.is_empty());
    }

    #[test]
    fn test_trait_analysis() {
        let source = r#"
trait MyTrait {
fn method_a(&self);
fn method_b(&self) -> i32;
}
"#;
        let parsed = parse(source);
        let mut tracker = TraitMethodTracker::new();
        tracker.analyze_file(&parsed);
        for declared in &["method_a", "method_b"] {
            assert!(tracker.is_trait_method(declared));
        }
        assert!(!tracker.is_trait_method("unknown_method"));
        assert_eq!(
            tracker.get_trait_for_method("method_a").map(String::as_str),
            Some("MyTrait")
        );
    }

    #[test]
    fn test_required_imports() {
        let trait_src = r#"
trait SimdOps {
fn simd_add();
fn simd_mul();
}
"#;
        let caller_src = r#"
fn caller() {
f32::simd_add();
}
"#;
        let trait_file = parse(trait_src);
        let caller_file = parse(caller_src);

        let mut tracker = TraitMethodTracker::new();
        tracker.analyze_file(&trait_file);
        tracker.register_trait_module("SimdOps", "functions");

        let needed = tracker.get_required_trait_imports(&caller_file.items, "helpers");
        assert!(needed.contains_key("SimdOps"));
        assert_eq!(needed.get("SimdOps").map(String::as_str), Some("functions"));
    }
}