use im::{HashMap, HashSet};
use std::time::SystemTime;
use syn::{File, Item, ItemFn, ItemImpl, Path as SynPath};
/// Result of recognising a function as a visitor-pattern method.
#[derive(Debug, Clone)]
pub struct VisitorInfo {
    /// Visitor trait name attributed to the method.
    pub trait_name: String,
    /// Name of the analyzed function.
    pub method_name: String,
    /// Largest number of arms found on a single `match` in the function body.
    pub arm_count: usize,
    /// Whether a catch-all arm (`_` or a bare identifier pattern) was seen.
    pub is_exhaustive: bool,
    /// Heuristic confidence of the detection (the detectors here assign 0.5, 0.8 or 0.9).
    pub confidence: f32,
}
/// A detected code pattern together with its raw and scaled complexity.
#[derive(Debug, Clone)]
pub struct PatternInfo {
    /// Which pattern was recognised; selects the scaling rule.
    pub pattern_type: PatternType,
    /// Unscaled complexity (the detectors in this module use the largest `match` arm count).
    pub base_complexity: u32,
    /// Complexity after pattern-specific scaling (0 is used as a placeholder before scaling).
    pub adjusted_complexity: u32,
    /// Heuristic confidence in the detection.
    pub confidence: f32,
}
/// Structural patterns that earn a complexity discount.
#[derive(Clone, Debug, PartialEq)]
pub enum PatternType {
    /// Method of a visitor/fold-style trait or a visitor-named function; scaled logarithmically.
    Visitor,
    /// A single large `match` classified as exhaustive; scaled by square root.
    ExhaustiveMatch,
    /// A single large `match` whose arm bodies are all trivial expressions; scaled to a fifth.
    SimpleMapping,
    /// No recognised pattern; complexity is left unchanged.
    Standard,
}
/// Cached pattern-detection results for one source file.
///
/// NOTE(review): nothing in the visible code reads or writes this cache; the field
/// meanings below are inferred from their names and should be confirmed against its users.
#[derive(Debug, Clone)]
pub struct PatternCache {
    /// Presumably a hash of the file contents the cache was built from — confirm the scheme.
    pub file_hash: u64,
    /// Detected patterns, presumably keyed by function name — confirm.
    pub patterns: HashMap<String, PatternInfo>,
    /// Presumably the moment the cache entry was built — confirm.
    pub timestamp: SystemTime,
}
/// Detects visitor-pattern functions, either through `impl` blocks of visitor-like traits
/// or through visitor-style function names.
pub struct VisitorPatternDetector {
    /// Trait names whose `impl` blocks are treated as visitor implementations.
    visitor_traits: HashSet<String>,
}
impl Default for VisitorPatternDetector {
fn default() -> Self {
Self::new()
}
}
impl VisitorPatternDetector {
pub fn new() -> Self {
let mut visitor_traits = HashSet::new();
visitor_traits.insert("Visit".to_string());
visitor_traits.insert("Visitor".to_string());
visitor_traits.insert("Fold".to_string());
visitor_traits.insert("VisitMut".to_string());
visitor_traits.insert("Walker".to_string());
visitor_traits.insert("Traverser".to_string());
Self { visitor_traits }
}
pub fn add_visitor_trait(&mut self, trait_name: String) {
self.visitor_traits.insert(trait_name);
}
pub fn detect_visitor_pattern(&mut self, file: &File, func: &ItemFn) -> Option<VisitorInfo> {
for item in &file.items {
if let Item::Impl(impl_block) = item {
if self.is_visitor_trait(impl_block) && self.contains_function(impl_block, func) {
return Some(self.analyze_visitor(func));
}
}
}
None
}
fn is_visitor_trait(&self, impl_block: &ItemImpl) -> bool {
if let Some((_, path, _)) = &impl_block.trait_ {
if let Some(trait_name) = self.extract_trait_name(path) {
if self.visitor_traits.contains(&trait_name) {
return true;
}
if trait_name.starts_with("Visit")
|| trait_name.starts_with("Visitor")
|| trait_name.starts_with("Fold")
{
return true;
}
}
}
false
}
fn extract_trait_name(&self, path: &SynPath) -> Option<String> {
path.segments.last().map(|seg| seg.ident.to_string())
}
fn contains_function(&self, impl_block: &ItemImpl, func: &ItemFn) -> bool {
let func_name = func.sig.ident.to_string();
impl_block.items.iter().any(|item| {
if let syn::ImplItem::Fn(method) = item {
method.sig.ident == func_name
} else {
false
}
})
}
fn analyze_visitor(&self, func: &ItemFn) -> VisitorInfo {
use syn::visit::Visit;
let mut visitor = MatchArmCounter::default();
visitor.visit_block(&func.block);
VisitorInfo {
trait_name: "Visit".to_string(), method_name: func.sig.ident.to_string(),
arm_count: visitor.max_arms,
is_exhaustive: visitor.has_wildcard,
confidence: if visitor.max_arms > 0 { 0.9 } else { 0.5 },
}
}
pub fn detect_visitor_by_pattern(&self, func: &ItemFn) -> Option<VisitorInfo> {
let name = func.sig.ident.to_string();
if name.starts_with("visit_")
|| name.starts_with("walk_")
|| name.starts_with("traverse_")
|| name.starts_with("fold_")
{
use syn::visit::Visit;
let mut visitor = MatchArmCounter::default();
visitor.visit_block(&func.block);
if visitor.max_arms >= 3 {
return Some(VisitorInfo {
trait_name: "Visitor".to_string(),
method_name: name,
arm_count: visitor.max_arms,
is_exhaustive: visitor.has_wildcard,
confidence: 0.8,
});
}
}
None
}
}
/// Syntax visitor that records the largest `match` arm count and whether any arm is a
/// catch-all.
#[derive(Default)]
struct MatchArmCounter {
    /// Most arms seen on any single `match`, nested ones included.
    max_arms: usize,
    /// Set once any arm's pattern is `_` or a bare identifier.
    /// NOTE(review): a bare identifier may be a unit variant (e.g. `None`) rather than a
    /// binding, so this is a heuristic.
    has_wildcard: bool,
}

impl<'ast> syn::visit::Visit<'ast> for MatchArmCounter {
    fn visit_expr_match(&mut self, node: &'ast syn::ExprMatch) {
        let arms = &node.arms;
        if arms.len() > self.max_arms {
            self.max_arms = arms.len();
        }
        let catch_all = arms
            .iter()
            .any(|arm| matches!(arm.pat, syn::Pat::Wild(_) | syn::Pat::Ident(_)));
        if catch_all {
            self.has_wildcard = true;
        }
        // Keep descending so nested `match` expressions are counted too.
        syn::visit::visit_expr_match(self, node);
    }
}
/// Summary of a function's `match` structure, as produced by [`MatchAnalyzer`].
#[derive(Debug, Clone)]
pub struct MatchCharacteristics {
    /// Classification of the function's primary `match`.
    pub pattern_type: PatternType,
    /// Number of arms of the primary `match` (0 when none qualified).
    pub arm_count: usize,
    /// Largest per-arm complexity.
    /// NOTE(review): the analyzer in this module never computes this, so it is always 0.
    pub max_arm_complexity: u32,
    /// Whether the primary `match` was judged to be a simple mapping.
    pub is_simple_mapping: bool,
    /// Whether the primary `match` has a `_` arm.
    pub has_default: bool,
}

impl Default for MatchCharacteristics {
    /// Result used when no qualifying `match` is found: `Standard`, zero counts, flags cleared.
    fn default() -> Self {
        MatchCharacteristics {
            is_simple_mapping: false,
            has_default: false,
            arm_count: 0,
            max_arm_complexity: 0,
            pattern_type: PatternType::Standard,
        }
    }
}
/// Classifies a function by the shape of its single dominant `match` expression.
pub struct MatchAnalyzer;

impl MatchAnalyzer {
    /// Analyzes the `match` expressions in `func`'s body.
    ///
    /// Only a body containing exactly one `match`, with at least three arms, is classified;
    /// anything else yields [`MatchCharacteristics::default`].
    pub fn analyze_match_pattern(&self, func: &ItemFn) -> MatchCharacteristics {
        use syn::visit::Visit;

        let mut stats = MatchPatternVisitor::default();
        stats.visit_block(&func.block);

        if stats.match_count != 1 || !stats.is_primary_match {
            return MatchCharacteristics::default();
        }

        let pattern_type = match (stats.is_simple_mapping, stats.all_arms_simple) {
            (true, _) => PatternType::SimpleMapping,
            (false, true) => PatternType::ExhaustiveMatch,
            (false, false) => PatternType::Standard,
        };

        MatchCharacteristics {
            pattern_type,
            arm_count: stats.max_arms,
            max_arm_complexity: stats.max_arm_complexity,
            is_simple_mapping: stats.is_simple_mapping,
            has_default: stats.has_wildcard,
        }
    }
}
/// Syntax visitor that gathers the statistics [`MatchAnalyzer`] uses to classify a function.
///
/// The per-`match` flags (`is_simple_mapping`, `all_arms_simple`) are overwritten by every
/// `match` visited, so they are only meaningful when `match_count == 1`.
#[derive(Default)]
struct MatchPatternVisitor {
    /// Number of `match` expressions seen, nested ones included.
    match_count: usize,
    /// Largest arm count of any visited `match`.
    max_arms: usize,
    /// NOTE(review): never written by this visitor, so it always stays 0 — confirm whether a
    /// per-arm complexity metric was meant to be computed here.
    max_arm_complexity: u32,
    /// Every arm body of the last visited `match` is a literal or a path (value-only arms).
    is_simple_mapping: bool,
    /// Some visited `match` has at least three arms.
    is_primary_match: bool,
    /// Some visited `match` has a `_` arm.
    has_wildcard: bool,
    /// Every arm body of the last visited `match` is a literal, a path, or a
    /// `return`/`break`/`continue` expression.
    all_arms_simple: bool,
}

impl<'ast> syn::visit::Visit<'ast> for MatchPatternVisitor {
    fn visit_expr_match(&mut self, match_expr: &'ast syn::ExprMatch) {
        self.match_count += 1;
        self.max_arms = self.max_arms.max(match_expr.arms.len());

        // `MatchAnalyzer` checks `is_simple_mapping` before `all_arms_simple`, so the mapping
        // predicate must be strictly narrower or `ExhaustiveMatch` can never be produced.
        // Value-only arms form a mapping; control-flow arms are simple but not a mapping.
        self.is_simple_mapping = match_expr
            .arms
            .iter()
            .all(|arm| matches!(&*arm.body, syn::Expr::Lit(_) | syn::Expr::Path(_)));
        self.all_arms_simple = match_expr.arms.iter().all(|arm| {
            matches!(
                &*arm.body,
                syn::Expr::Lit(_)
                    | syn::Expr::Path(_)
                    | syn::Expr::Return(_)
                    | syn::Expr::Break(_)
                    | syn::Expr::Continue(_)
            )
        });

        if match_expr.arms.len() >= 3 {
            self.is_primary_match = true;
        }
        if match_expr
            .arms
            .iter()
            .any(|arm| matches!(arm.pat, syn::Pat::Wild(_)))
        {
            self.has_wildcard = true;
        }
        syn::visit::visit_expr_match(self, match_expr);
    }
}
/// Scales a raw complexity score according to the recognised pattern.
///
/// * `Visitor` — `ceil(log2(base))`, at least 1.
/// * `ExhaustiveMatch` — `ceil(sqrt(base))`, at least 2.
/// * `SimpleMapping` — `floor(base / 5)`, at least 1.
/// * `Standard` — `base` unchanged.
///
/// Only `pattern.pattern_type` is read; the score itself comes from `base_complexity`.
/// Because of the floors, the result can exceed `base_complexity` for very small inputs
/// (e.g. `ExhaustiveMatch` with `base_complexity == 1` yields 2).
///
/// The arithmetic is exact over the whole `u32` range. Routing through `f32` (24-bit
/// mantissa) would round inputs above 2^24 and skew the result.
pub fn apply_pattern_scaling(base_complexity: u32, pattern: &PatternInfo) -> u32 {
    match pattern.pattern_type {
        PatternType::Visitor => ceil_log2(base_complexity).max(1),
        PatternType::ExhaustiveMatch => ceil_sqrt(base_complexity).max(2),
        PatternType::SimpleMapping => (base_complexity / 5).max(1),
        PatternType::Standard => base_complexity,
    }
}

/// `ceil(log2(n))` via bit counting; returns 0 for `n <= 1`.
fn ceil_log2(n: u32) -> u32 {
    if n <= 1 {
        0
    } else {
        u32::BITS - (n - 1).leading_zeros()
    }
}

/// `ceil(sqrt(n))`.
fn ceil_sqrt(n: u32) -> u32 {
    // Every u32 is exactly representable in f64, and IEEE-754 sqrt is correctly rounded.
    // Perfect squares therefore give an exact integer, and non-squares never round onto one.
    f64::from(n).sqrt().ceil() as u32
}
pub fn detect_visitor_pattern(file: &File, func: &ItemFn) -> Option<PatternInfo> {
let mut detector = VisitorPatternDetector::new();
if let Some(visitor_info) = detector.detect_visitor_pattern(file, func) {
let base = visitor_info.arm_count as u32;
let adjusted = apply_pattern_scaling(
base,
&PatternInfo {
pattern_type: PatternType::Visitor,
base_complexity: base,
adjusted_complexity: 0, confidence: visitor_info.confidence,
},
);
return Some(PatternInfo {
pattern_type: PatternType::Visitor,
base_complexity: base,
adjusted_complexity: adjusted,
confidence: visitor_info.confidence,
});
}
if let Some(visitor_info) = detector.detect_visitor_by_pattern(func) {
let base = visitor_info.arm_count as u32;
let adjusted = apply_pattern_scaling(
base,
&PatternInfo {
pattern_type: PatternType::Visitor,
base_complexity: base,
adjusted_complexity: 0,
confidence: visitor_info.confidence,
},
);
return Some(PatternInfo {
pattern_type: PatternType::Visitor,
base_complexity: base,
adjusted_complexity: adjusted,
confidence: visitor_info.confidence,
});
}
let analyzer = MatchAnalyzer;
let match_info = analyzer.analyze_match_pattern(func);
if match_info.arm_count >= 3 {
let base = match_info.arm_count as u32;
let pattern_type = match_info.pattern_type;
let adjusted = apply_pattern_scaling(
base,
&PatternInfo {
pattern_type: pattern_type.clone(),
base_complexity: base,
adjusted_complexity: 0,
confidence: 0.7,
},
);
return Some(PatternInfo {
pattern_type,
base_complexity: base,
adjusted_complexity: adjusted,
confidence: 0.7,
});
}
None
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Builds a `PatternInfo` for scaling tests; only `pattern_type` affects scaling.
    fn make_pattern(pattern_type: PatternType, base: u32) -> PatternInfo {
        PatternInfo {
            pattern_type,
            base_complexity: base,
            adjusted_complexity: 0,
            confidence: 0.8,
        }
    }

    #[test]
    fn test_visitor_trait_detection() {
        let detector = VisitorPatternDetector::new();
        let impl_block: ItemImpl = parse_quote! {
            impl Visit for MyVisitor {
                fn visit_expr(&mut self, expr: &Expr) {
                    match expr {
                        Expr::Binary(b) => self.visit_binary(b),
                        Expr::Unary(u) => self.visit_unary(u),
                        Expr::Call(c) => self.visit_call(c),
                        _ => {}
                    }
                }
            }
        };
        assert!(detector.is_visitor_trait(&impl_block));
    }

    #[test]
    fn test_non_visitor_impls_are_rejected() {
        let detector = VisitorPatternDetector::new();
        let inherent: ItemImpl = parse_quote! {
            impl MyVisitor {
                fn visit_expr(&mut self) {}
            }
        };
        assert!(!detector.is_visitor_trait(&inherent));
        let unrelated: ItemImpl = parse_quote! {
            impl Display for MyVisitor {}
        };
        assert!(!detector.is_visitor_trait(&unrelated));
    }

    #[test]
    fn test_custom_visitor_trait() {
        let mut detector = VisitorPatternDetector::new();
        let impl_block: ItemImpl = parse_quote! {
            impl Inspector for MyVisitor {}
        };
        assert!(!detector.is_visitor_trait(&impl_block));
        detector.add_visitor_trait("Inspector".to_string());
        assert!(detector.is_visitor_trait(&impl_block));
    }

    #[test]
    fn test_visitor_by_pattern() {
        let detector = VisitorPatternDetector::new();
        let func: ItemFn = parse_quote! {
            fn visit_expr(&mut self, expr: &Expr) {
                match expr {
                    Expr::Binary(b) => {},
                    Expr::Unary(u) => {},
                    Expr::Call(c) => {},
                    Expr::Method(m) => {},
                    _ => {}
                }
            }
        };
        let info = detector
            .detect_visitor_by_pattern(&func)
            .expect("visit_expr with five arms should be detected");
        assert_eq!(info.method_name, "visit_expr");
        assert_eq!(info.arm_count, 5);
    }

    #[test]
    fn test_visitor_by_pattern_requires_prefix_and_arms() {
        let detector = VisitorPatternDetector::new();
        let wrong_name: ItemFn = parse_quote! {
            fn compute(x: u8) -> u8 {
                match x {
                    0 => 1,
                    1 => 2,
                    2 => 3,
                    _ => 0,
                }
            }
        };
        assert!(detector.detect_visitor_by_pattern(&wrong_name).is_none());
        let too_few_arms: ItemFn = parse_quote! {
            fn visit_flag(flag: bool) -> u8 {
                match flag {
                    true => 1,
                    false => 0,
                }
            }
        };
        assert!(detector.detect_visitor_by_pattern(&too_few_arms).is_none());
    }

    #[test]
    fn test_trait_impl_detection_end_to_end() {
        let file: File = parse_quote! {
            impl Fold for Folder {
                fn fold_expr(&mut self) {}
            }
        };
        let func: ItemFn = parse_quote! {
            fn fold_expr(&mut self) {}
        };
        let info = detect_visitor_pattern(&file, &func).expect("fold_expr should be detected");
        assert_eq!(info.pattern_type, PatternType::Visitor);
        assert_eq!(info.base_complexity, 0);
        assert_eq!(info.adjusted_complexity, 1);
        assert_eq!(info.confidence, 0.5);
    }

    #[test]
    fn test_simple_mapping_detection() {
        let func: ItemFn = parse_quote! {
            fn code_for(x: u8) -> u8 {
                match x {
                    0 => 10,
                    1 => 20,
                    2 => 30,
                    _ => 0,
                }
            }
        };
        let info = MatchAnalyzer.analyze_match_pattern(&func);
        assert_eq!(info.pattern_type, PatternType::SimpleMapping);
        assert_eq!(info.arm_count, 4);
        assert!(info.has_default);
    }

    #[test]
    fn test_logarithmic_scaling() {
        let adjusted = apply_pattern_scaling(34, &make_pattern(PatternType::Visitor, 34));
        assert_eq!(adjusted, 6);
    }

    #[test]
    fn test_sqrt_scaling() {
        let adjusted = apply_pattern_scaling(16, &make_pattern(PatternType::ExhaustiveMatch, 16));
        assert_eq!(adjusted, 4);
    }

    #[test]
    fn test_simple_mapping_scaling() {
        let adjusted = apply_pattern_scaling(10, &make_pattern(PatternType::SimpleMapping, 10));
        assert_eq!(adjusted, 2);
    }

    #[test]
    fn test_scaling_floors() {
        assert_eq!(apply_pattern_scaling(0, &make_pattern(PatternType::Visitor, 0)), 1);
        assert_eq!(
            apply_pattern_scaling(1, &make_pattern(PatternType::ExhaustiveMatch, 1)),
            2
        );
        assert_eq!(
            apply_pattern_scaling(3, &make_pattern(PatternType::SimpleMapping, 3)),
            1
        );
        assert_eq!(apply_pattern_scaling(7, &make_pattern(PatternType::Standard, 7)), 7);
    }
}