use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Broad family a known trait belongs to.
///
/// Each category carries a weight (`TraitCategory::default_weight`) that
/// `MethodOrigin::weight` uses for trait-mandated methods.
///
/// NOTE(review): keep the variant order stable — serde's derive passes the
/// variant index to the serializer, and non-self-describing formats encode it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TraitCategory {
    /// AST/tree visitor traits (the default registry files `syn::Visit`,
    /// `syn::VisitMut` and `syn::Fold` here).
    Visitor,
    /// `serde` (de)serialization traits.
    Serialization,
    /// Iterator traits (`Iterator`, `IntoIterator`, `FromIterator`, ...).
    Iterator,
    /// Async traits (`Future`, `Stream`, `Sink`).
    Async,
    /// Comparison, ordering and hashing traits (`PartialEq`, `Ord`, `Hash`, ...).
    Comparison,
    /// Error-handling traits (`std::error::Error`).
    Error,
    /// Other standard-library traits (`Default`, `Clone`, `Display`, `From`, ...).
    Standard,
    /// Any trait the registry does not recognize; also the `Default` value.
    Custom,
}
impl TraitCategory {
#[must_use]
pub fn default_weight(&self) -> f64 {
match self {
Self::Visitor => 0.1, Self::Serialization => 0.1, Self::Iterator => 0.3, Self::Async => 0.3, Self::Standard => 0.2, Self::Comparison => 0.2, Self::Error => 0.3, Self::Custom => 0.4, }
}
#[must_use]
pub fn description(&self) -> &'static str {
match self {
Self::Visitor => "AST/tree visitor",
Self::Serialization => "serialization",
Self::Iterator => "iterator/stream",
Self::Async => "async runtime",
Self::Standard => "standard library trait",
Self::Comparison => "comparison/ordering",
Self::Error => "error handling",
Self::Custom => "custom trait",
}
}
}
impl Default for TraitCategory {
fn default() -> Self {
Self::Custom
}
}
/// How a trait's required method names are recognized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MethodPattern {
    /// Matches exactly this method name.
    Exact(String),
    /// Matches any method name that starts with this string.
    Prefix(String),
    /// Matches any method name that ends with this string.
    Suffix(String),
}
impl MethodPattern {
    /// Tests `method_name` against this pattern (case-sensitive).
    #[must_use]
    pub fn matches(&self, method_name: &str) -> bool {
        match self {
            MethodPattern::Exact(expected) => expected.as_str() == method_name,
            MethodPattern::Prefix(head) => method_name.starts_with(head.as_str()),
            MethodPattern::Suffix(tail) => method_name.ends_with(tail.as_str()),
        }
    }
}
/// Registry entry describing a trait whose required methods should be
/// recognized as trait-mandated.
#[derive(Debug, Clone)]
pub struct KnownTrait {
    /// Canonical path; also the key under which `KnownTraitRegistry::add`
    /// stores the entry.
    pub path: String,
    /// Additional spellings that match exactly (e.g. fully qualified paths).
    pub aliases: Vec<String>,
    /// Patterns recognizing this trait's method names.
    pub method_patterns: Vec<MethodPattern>,
    /// Category used to weight methods mandated by this trait.
    pub category: TraitCategory,
}
impl KnownTrait {
    /// Returns `true` if `trait_name` refers to this trait.
    ///
    /// Matching is tried in this order:
    /// 1. `trait_name` equals `path` exactly;
    /// 2. `trait_name` equals one of `aliases` exactly;
    /// 3. after removing any generic argument list (everything from the first
    ///    `<`, so `From<String>` is treated as `From`) and surrounding
    ///    whitespace, `trait_name` is the last `::` segment of `path`, or any
    ///    path that ends in `::` followed by that segment.
    ///
    /// Rule 3 is intentionally loose. Any trait whose last segment equals this
    /// trait's last segment is treated as this trait, for example a
    /// user-defined `my_crate::Error`.
    #[must_use]
    pub fn matches_trait_name(&self, trait_name: &str) -> bool {
        if self.path == trait_name || self.aliases.iter().any(|alias| alias == trait_name) {
            return true;
        }
        let base = Self::strip_generic_args(trait_name);
        let own_base = Self::strip_generic_args(&self.path);
        let primary_name = own_base.rsplit("::").next().unwrap_or(own_base);
        // `rest` is whatever precedes the final segment. It must be empty or end
        // in a path separator, so that e.g. `TryFrom` does not match `From`.
        // `trim_end` tolerates token-stream spacing such as `std :: convert :: From`.
        matches!(
            base.strip_suffix(primary_name),
            Some(rest) if rest.is_empty() || rest.trim_end().ends_with("::")
        )
    }

    /// Returns `true` if any of `method_patterns` matches `method_name`.
    #[must_use]
    pub fn matches_method(&self, method_name: &str) -> bool {
        self.method_patterns.iter().any(|p| p.matches(method_name))
    }

    /// Removes a generic argument list (from the first `<` onward) and the
    /// surrounding whitespace: `"From < String >"` becomes `"From"`.
    fn strip_generic_args(name: &str) -> &str {
        name.find('<').map_or(name, |open| &name[..open]).trim()
    }
}
/// Lookup table of known traits, keyed by each trait's canonical `path`.
///
/// `KnownTraitRegistry::new` returns an empty table. The `Default` impl
/// returns one pre-populated with the built-in traits.
#[derive(Debug, Clone)]
pub struct KnownTraitRegistry {
    traits: HashMap<String, KnownTrait>,
}
impl KnownTraitRegistry {
    /// Creates an empty registry.
    ///
    /// This differs from `KnownTraitRegistry::default()`, which is pre-populated.
    #[must_use]
    pub fn new() -> Self {
        Self {
            traits: HashMap::new(),
        }
    }

    /// Registers `known_trait` under its `path`, replacing any existing entry
    /// with the same path.
    pub fn add(&mut self, known_trait: KnownTrait) {
        self.traits.insert(known_trait.path.clone(), known_trait);
    }

    /// Looks up a trait by its exact canonical path. Aliases are not consulted.
    #[must_use]
    pub fn get(&self, path: &str) -> Option<&KnownTrait> {
        self.traits.get(path)
    }

    /// Finds the registered trait that `trait_name` refers to.
    ///
    /// Candidates are ranked so that the result does not depend on `HashMap`
    /// iteration order:
    /// 1. the trait whose canonical path equals `trait_name`;
    /// 2. otherwise a trait listing `trait_name` among its aliases;
    /// 3. otherwise any trait for which `KnownTrait::matches_trait_name` holds.
    ///
    /// Ties within ranks 2 and 3 go to the lexicographically smallest path.
    #[must_use]
    pub fn find(&self, trait_name: &str) -> Option<&KnownTrait> {
        if let Some(exact) = self.traits.get(trait_name) {
            return Some(exact);
        }
        self.best_match(|t| t.aliases.iter().any(|alias| alias == trait_name))
            .or_else(|| self.best_match(|t| t.matches_trait_name(trait_name)))
    }

    /// Returns the category of the trait `find` resolves `trait_name` to, or
    /// `TraitCategory::Custom` when it is not recognized.
    #[must_use]
    pub fn categorize_trait(&self, trait_name: &str) -> TraitCategory {
        self.find(trait_name)
            .map(|t| t.category)
            .unwrap_or(TraitCategory::Custom)
    }

    /// Weight of a method mandated by `trait_name`: the `default_weight` of
    /// its category.
    #[must_use]
    pub fn method_weight(&self, trait_name: &str) -> f64 {
        self.categorize_trait(trait_name).default_weight()
    }

    /// Iterates over the registered traits in unspecified (hash) order.
    pub fn iter(&self) -> impl Iterator<Item = &KnownTrait> {
        self.traits.values()
    }

    /// Number of registered traits.
    #[must_use]
    pub fn len(&self) -> usize {
        self.traits.len()
    }

    /// Returns `true` if no traits are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.traits.is_empty()
    }

    /// Among the traits satisfying `pred`, returns the one with the smallest
    /// path, so the choice is deterministic.
    fn best_match<F>(&self, pred: F) -> Option<&KnownTrait>
    where
        F: Fn(&KnownTrait) -> bool,
    {
        self.traits
            .values()
            .filter(|&t| pred(t))
            .min_by(|a, b| a.path.cmp(&b.path))
    }
}
impl Default for KnownTraitRegistry {
    /// Returns a registry pre-populated by `build_default_registry`.
    fn default() -> KnownTraitRegistry {
        build_default_registry()
    }
}
fn build_default_registry() -> KnownTraitRegistry {
let mut registry = KnownTraitRegistry::new();
registry.add(KnownTrait {
path: "syn::Visit".into(),
aliases: vec![
"syn::visit::Visit".into(),
"Visit".into(),
"syn::Visit<'ast>".into(),
],
method_patterns: vec![MethodPattern::Prefix("visit_".into())],
category: TraitCategory::Visitor,
});
registry.add(KnownTrait {
path: "syn::VisitMut".into(),
aliases: vec!["syn::visit_mut::VisitMut".into(), "VisitMut".into()],
method_patterns: vec![MethodPattern::Prefix("visit_".into())],
category: TraitCategory::Visitor,
});
registry.add(KnownTrait {
path: "syn::Fold".into(),
aliases: vec!["syn::fold::Fold".into(), "Fold".into()],
method_patterns: vec![MethodPattern::Prefix("fold_".into())],
category: TraitCategory::Visitor,
});
registry.add(KnownTrait {
path: "serde::Serialize".into(),
aliases: vec!["Serialize".into(), "serde::ser::Serialize".into()],
method_patterns: vec![MethodPattern::Exact("serialize".into())],
category: TraitCategory::Serialization,
});
registry.add(KnownTrait {
path: "serde::Deserialize".into(),
aliases: vec!["Deserialize".into(), "serde::de::Deserialize".into()],
method_patterns: vec![MethodPattern::Exact("deserialize".into())],
category: TraitCategory::Serialization,
});
registry.add(KnownTrait {
path: "serde::Serializer".into(),
aliases: vec!["Serializer".into()],
method_patterns: vec![MethodPattern::Prefix("serialize_".into())],
category: TraitCategory::Serialization,
});
registry.add(KnownTrait {
path: "serde::Deserializer".into(),
aliases: vec!["Deserializer".into()],
method_patterns: vec![MethodPattern::Prefix("deserialize_".into())],
category: TraitCategory::Serialization,
});
registry.add(KnownTrait {
path: "Iterator".into(),
aliases: vec!["std::iter::Iterator".into(), "core::iter::Iterator".into()],
method_patterns: vec![
MethodPattern::Exact("next".into()),
MethodPattern::Exact("size_hint".into()),
MethodPattern::Exact("count".into()),
MethodPattern::Exact("last".into()),
MethodPattern::Exact("nth".into()),
],
category: TraitCategory::Iterator,
});
registry.add(KnownTrait {
path: "IntoIterator".into(),
aliases: vec![
"std::iter::IntoIterator".into(),
"core::iter::IntoIterator".into(),
],
method_patterns: vec![MethodPattern::Exact("into_iter".into())],
category: TraitCategory::Iterator,
});
registry.add(KnownTrait {
path: "FromIterator".into(),
aliases: vec![
"std::iter::FromIterator".into(),
"core::iter::FromIterator".into(),
],
method_patterns: vec![MethodPattern::Exact("from_iter".into())],
category: TraitCategory::Iterator,
});
registry.add(KnownTrait {
path: "ExactSizeIterator".into(),
aliases: vec![
"std::iter::ExactSizeIterator".into(),
"core::iter::ExactSizeIterator".into(),
],
method_patterns: vec![MethodPattern::Exact("len".into())],
category: TraitCategory::Iterator,
});
registry.add(KnownTrait {
path: "DoubleEndedIterator".into(),
aliases: vec![
"std::iter::DoubleEndedIterator".into(),
"core::iter::DoubleEndedIterator".into(),
],
method_patterns: vec![
MethodPattern::Exact("next_back".into()),
MethodPattern::Exact("nth_back".into()),
],
category: TraitCategory::Iterator,
});
registry.add(KnownTrait {
path: "Future".into(),
aliases: vec![
"std::future::Future".into(),
"core::future::Future".into(),
"futures::Future".into(),
],
method_patterns: vec![MethodPattern::Exact("poll".into())],
category: TraitCategory::Async,
});
registry.add(KnownTrait {
path: "Stream".into(),
aliases: vec![
"futures::Stream".into(),
"futures_core::Stream".into(),
"tokio_stream::Stream".into(),
],
method_patterns: vec![
MethodPattern::Exact("poll_next".into()),
MethodPattern::Exact("size_hint".into()),
],
category: TraitCategory::Async,
});
registry.add(KnownTrait {
path: "Sink".into(),
aliases: vec!["futures::Sink".into(), "futures_sink::Sink".into()],
method_patterns: vec![
MethodPattern::Exact("poll_ready".into()),
MethodPattern::Exact("start_send".into()),
MethodPattern::Exact("poll_flush".into()),
MethodPattern::Exact("poll_close".into()),
],
category: TraitCategory::Async,
});
registry.add(KnownTrait {
path: "PartialEq".into(),
aliases: vec!["std::cmp::PartialEq".into(), "core::cmp::PartialEq".into()],
method_patterns: vec![
MethodPattern::Exact("eq".into()),
MethodPattern::Exact("ne".into()),
],
category: TraitCategory::Comparison,
});
registry.add(KnownTrait {
path: "Eq".into(),
aliases: vec!["std::cmp::Eq".into(), "core::cmp::Eq".into()],
method_patterns: vec![],
category: TraitCategory::Comparison,
});
registry.add(KnownTrait {
path: "PartialOrd".into(),
aliases: vec![
"std::cmp::PartialOrd".into(),
"core::cmp::PartialOrd".into(),
],
method_patterns: vec![
MethodPattern::Exact("partial_cmp".into()),
MethodPattern::Exact("lt".into()),
MethodPattern::Exact("le".into()),
MethodPattern::Exact("gt".into()),
MethodPattern::Exact("ge".into()),
],
category: TraitCategory::Comparison,
});
registry.add(KnownTrait {
path: "Ord".into(),
aliases: vec!["std::cmp::Ord".into(), "core::cmp::Ord".into()],
method_patterns: vec![
MethodPattern::Exact("cmp".into()),
MethodPattern::Exact("max".into()),
MethodPattern::Exact("min".into()),
MethodPattern::Exact("clamp".into()),
],
category: TraitCategory::Comparison,
});
registry.add(KnownTrait {
path: "Hash".into(),
aliases: vec!["std::hash::Hash".into(), "core::hash::Hash".into()],
method_patterns: vec![
MethodPattern::Exact("hash".into()),
MethodPattern::Exact("hash_slice".into()),
],
category: TraitCategory::Comparison,
});
registry.add(KnownTrait {
path: "Error".into(),
aliases: vec!["std::error::Error".into()],
method_patterns: vec![
MethodPattern::Exact("source".into()),
MethodPattern::Exact("description".into()),
MethodPattern::Exact("cause".into()),
],
category: TraitCategory::Error,
});
registry.add(KnownTrait {
path: "Default".into(),
aliases: vec![
"std::default::Default".into(),
"core::default::Default".into(),
],
method_patterns: vec![MethodPattern::Exact("default".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Clone".into(),
aliases: vec!["std::clone::Clone".into(), "core::clone::Clone".into()],
method_patterns: vec![
MethodPattern::Exact("clone".into()),
MethodPattern::Exact("clone_from".into()),
],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Drop".into(),
aliases: vec!["std::ops::Drop".into(), "core::ops::Drop".into()],
method_patterns: vec![MethodPattern::Exact("drop".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Display".into(),
aliases: vec!["std::fmt::Display".into(), "core::fmt::Display".into()],
method_patterns: vec![MethodPattern::Exact("fmt".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Debug".into(),
aliases: vec!["std::fmt::Debug".into(), "core::fmt::Debug".into()],
method_patterns: vec![MethodPattern::Exact("fmt".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "From".into(),
aliases: vec!["std::convert::From".into(), "core::convert::From".into()],
method_patterns: vec![MethodPattern::Exact("from".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Into".into(),
aliases: vec!["std::convert::Into".into(), "core::convert::Into".into()],
method_patterns: vec![MethodPattern::Exact("into".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "TryFrom".into(),
aliases: vec![
"std::convert::TryFrom".into(),
"core::convert::TryFrom".into(),
],
method_patterns: vec![MethodPattern::Exact("try_from".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "TryInto".into(),
aliases: vec![
"std::convert::TryInto".into(),
"core::convert::TryInto".into(),
],
method_patterns: vec![MethodPattern::Exact("try_into".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "AsRef".into(),
aliases: vec!["std::convert::AsRef".into(), "core::convert::AsRef".into()],
method_patterns: vec![MethodPattern::Exact("as_ref".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "AsMut".into(),
aliases: vec!["std::convert::AsMut".into(), "core::convert::AsMut".into()],
method_patterns: vec![MethodPattern::Exact("as_mut".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Deref".into(),
aliases: vec!["std::ops::Deref".into(), "core::ops::Deref".into()],
method_patterns: vec![MethodPattern::Exact("deref".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "DerefMut".into(),
aliases: vec!["std::ops::DerefMut".into(), "core::ops::DerefMut".into()],
method_patterns: vec![MethodPattern::Exact("deref_mut".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Borrow".into(),
aliases: vec!["std::borrow::Borrow".into(), "core::borrow::Borrow".into()],
method_patterns: vec![MethodPattern::Exact("borrow".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "BorrowMut".into(),
aliases: vec![
"std::borrow::BorrowMut".into(),
"core::borrow::BorrowMut".into(),
],
method_patterns: vec![MethodPattern::Exact("borrow_mut".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "Index".into(),
aliases: vec!["std::ops::Index".into(), "core::ops::Index".into()],
method_patterns: vec![MethodPattern::Exact("index".into())],
category: TraitCategory::Standard,
});
registry.add(KnownTrait {
path: "IndexMut".into(),
aliases: vec!["std::ops::IndexMut".into(), "core::ops::IndexMut".into()],
method_patterns: vec![MethodPattern::Exact("index_mut".into())],
category: TraitCategory::Standard,
});
registry
}
/// Why a method exists on a type, as decided by `classify_method_origin`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MethodOrigin {
    /// The method name appears in one of the supplied trait impls.
    TraitMandated {
        /// Trait path as recorded on the impl (`TraitImplInfo::trait_path`).
        trait_name: String,
        /// Category the registry assigned to `trait_name`.
        category: TraitCategory,
    },
    /// The method name was not found in any of the supplied trait impls.
    SelfChosen,
}
impl MethodOrigin {
#[must_use]
pub fn weight(&self) -> f64 {
match self {
Self::TraitMandated { category, .. } => category.default_weight(),
Self::SelfChosen => 1.0,
}
}
#[must_use]
pub fn is_trait_mandated(&self) -> bool {
matches!(self, Self::TraitMandated { .. })
}
#[must_use]
pub fn is_extractable(&self) -> bool {
matches!(self, Self::SelfChosen)
}
}
impl Default for MethodOrigin {
fn default() -> Self {
Self::SelfChosen
}
}
/// A method name paired with its origin and the weight derived from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifiedMethod {
    /// Method name.
    pub name: String,
    /// Whether the method is trait-mandated or self-chosen.
    pub origin: MethodOrigin,
    /// Weight that `ClassifiedMethod::new` sets from `origin.weight()`.
    /// NOTE(review): the field is public, so it can drift from `origin` if
    /// mutated directly.
    pub weight: f64,
}
impl ClassifiedMethod {
    /// Builds a classified method whose `weight` is taken from `origin.weight()`.
    #[must_use]
    pub fn new(name: String, origin: MethodOrigin) -> Self {
        // Fields are evaluated in the order written, so `origin` is borrowed for
        // the weight before it is moved into the struct.
        Self {
            weight: origin.weight(),
            name,
            origin,
        }
    }
}
/// One trait impl on a type: the trait path and the method names attributed to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraitImplInfo {
    /// Trait path as recorded for the impl.
    /// NOTE(review): may include generic arguments depending on the producer — confirm.
    pub trait_path: String,
    /// Names of the methods belonging to this impl.
    pub method_names: Vec<String>,
    /// Category supplied by the caller.
    /// NOTE(review): `classify_method_origin` re-derives the category from the
    /// registry and does not read this field.
    pub category: TraitCategory,
}
impl TraitImplInfo {
#[must_use]
pub fn new(trait_path: String, method_names: Vec<String>, category: TraitCategory) -> Self {
Self {
trait_path,
method_names,
category,
}
}
#[must_use]
pub fn contains_method(&self, method_name: &str) -> bool {
self.method_names.iter().any(|m| m == method_name)
}
}
/// Aggregate counts over a set of `ClassifiedMethod`s.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TraitMethodSummary {
    /// Number of trait-mandated methods.
    pub mandated_count: usize,
    /// Trait-mandated method count, keyed by trait name.
    pub by_trait: HashMap<String, usize>,
    /// Sum of every method's weight (trait-mandated and self-chosen).
    pub weighted_count: f64,
    /// Number of self-chosen (extractable) methods.
    pub extractable_count: usize,
    /// Total number of methods summarized.
    pub total_methods: usize,
}
impl TraitMethodSummary {
    /// Creates an empty summary with all counts at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a summary by tallying `methods`.
    ///
    /// Every method adds its `weight` to `weighted_count`. Trait-mandated
    /// methods are also counted per trait name in `by_trait`.
    #[must_use]
    pub fn from_classifications(methods: &[ClassifiedMethod]) -> Self {
        let mut by_trait: HashMap<String, usize> = HashMap::new();
        let mut mandated_count = 0;
        let mut extractable_count = 0;
        let mut weighted_count = 0.0;
        for method in methods {
            weighted_count += method.weight;
            match &method.origin {
                MethodOrigin::TraitMandated { trait_name, .. } => {
                    mandated_count += 1;
                    *by_trait.entry(trait_name.clone()).or_default() += 1;
                }
                MethodOrigin::SelfChosen => extractable_count += 1,
            }
        }
        Self {
            mandated_count,
            by_trait,
            weighted_count,
            extractable_count,
            total_methods: methods.len(),
        }
    }

    /// Fraction of methods that are self-chosen; `1.0` when there are no methods.
    #[must_use]
    pub fn extractable_ratio(&self) -> f64 {
        if self.total_methods == 0 {
            return 1.0;
        }
        self.extractable_count as f64 / self.total_methods as f64
    }

    /// Fraction of methods that are trait-mandated; `0.0` when there are no methods.
    #[must_use]
    pub fn mandated_ratio(&self) -> f64 {
        if self.total_methods == 0 {
            return 0.0;
        }
        self.mandated_count as f64 / self.total_methods as f64
    }

    /// Returns `true` when strictly more than half of the methods are trait-mandated.
    #[must_use]
    pub fn is_trait_dominated(&self) -> bool {
        self.mandated_ratio() > 0.5
    }

    /// Formats `by_trait` as `"Name (count), Name (count), ..."`.
    ///
    /// Entries are ordered by count, most frequent first. Ties are ordered by
    /// trait name so the output is deterministic despite `HashMap` iteration
    /// order. Returns an empty string when `by_trait` is empty.
    #[must_use]
    pub fn format_trait_breakdown(&self) -> String {
        let mut items: Vec<(&String, &usize)> = self.by_trait.iter().collect();
        // Trait names are unique map keys, so this comparator is a total order.
        items.sort_unstable_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
        items
            .iter()
            .map(|(trait_name, count)| format!("{} ({})", trait_name, count))
            .collect::<Vec<_>>()
            .join(", ")
    }
}
impl std::fmt::Display for TraitMethodSummary {
    /// Renders as `"<total> (<mandated> trait-mandated, <extractable> extractable)"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{} ({} trait-mandated, {} extractable)",
            self.total_methods, self.mandated_count, self.extractable_count
        )
    }
}
/// Decides whether `method_name` is required by one of `trait_impls`.
///
/// The first impl in slice order whose `method_names` contain `method_name`
/// wins. Its category is looked up in `registry` by trait path; the
/// `category` stored on the `TraitImplInfo` is not consulted. A method found
/// in no impl is `MethodOrigin::SelfChosen`.
#[must_use]
pub fn classify_method_origin(
    method_name: &str,
    trait_impls: &[TraitImplInfo],
    registry: &KnownTraitRegistry,
) -> MethodOrigin {
    let owning_impl = trait_impls
        .iter()
        .find(|info| info.contains_method(method_name));
    match owning_impl {
        Some(info) => MethodOrigin::TraitMandated {
            trait_name: info.trait_path.clone(),
            category: registry.categorize_trait(&info.trait_path),
        },
        None => MethodOrigin::SelfChosen,
    }
}
/// Classifies every name in `method_names` with `classify_method_origin` and
/// wraps each result in a `ClassifiedMethod`, preserving input order.
#[must_use]
pub fn classify_all_methods(
    method_names: &[String],
    trait_impls: &[TraitImplInfo],
    registry: &KnownTraitRegistry,
) -> Vec<ClassifiedMethod> {
    let mut classified = Vec::with_capacity(method_names.len());
    for method_name in method_names {
        let origin = classify_method_origin(method_name, trait_impls, registry);
        classified.push(ClassifiedMethod::new(method_name.clone(), origin));
    }
    classified
}
/// Returns the sum of each method's `weight`.
#[must_use]
pub fn calculate_trait_weighted_count(methods: &[ClassifiedMethod]) -> f64 {
    let weights = methods.iter().map(|method| method.weight);
    weights.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_method_pattern_exact() {
        let pattern = MethodPattern::Exact("next".into());
        assert!(pattern.matches("next"));
        assert!(!pattern.matches("next_back"));
        assert!(!pattern.matches(""));
    }

    #[test]
    fn test_method_pattern_prefix() {
        let pattern = MethodPattern::Prefix("visit_".into());
        assert!(pattern.matches("visit_expr"));
        assert!(pattern.matches("visit_stmt"));
        assert!(!pattern.matches("fold_expr"));
        assert!(!pattern.matches("visit"));
    }

    #[test]
    fn test_method_pattern_suffix() {
        let pattern = MethodPattern::Suffix("_mut".into());
        assert!(pattern.matches("get_mut"));
        assert!(pattern.matches("index_mut"));
        assert!(!pattern.matches("mut_get"));
    }

    #[test]
    fn test_known_trait_matches() {
        let trait_info = KnownTrait {
            path: "syn::Visit".into(),
            aliases: vec!["Visit".into()],
            method_patterns: vec![MethodPattern::Prefix("visit_".into())],
            category: TraitCategory::Visitor,
        };
        assert!(trait_info.matches_trait_name("syn::Visit"));
        assert!(trait_info.matches_trait_name("Visit"));
        // Any path ending in `::Visit` matches via the primary-name rule.
        assert!(trait_info.matches_trait_name("syn::visit::Visit"));
        assert!(!trait_info.matches_trait_name("Iterator"));
        assert!(!trait_info.matches_trait_name("VisitMut"));
    }

    #[test]
    fn test_registry_default() {
        let registry = KnownTraitRegistry::default();
        assert!(!registry.is_empty());
        assert!(registry.find("syn::Visit").is_some());
        assert!(registry.find("Iterator").is_some());
        assert!(registry.find("Default").is_some());
        assert!(registry.find("Clone").is_some());
        // Every built-in entry must resolve to itself by its canonical path.
        for known in registry.iter() {
            let found = registry
                .find(&known.path)
                .expect("every registered path resolves");
            assert_eq!(found.path, known.path);
        }
    }

    #[test]
    fn test_registry_categorize() {
        let registry = KnownTraitRegistry::default();
        assert_eq!(
            registry.categorize_trait("syn::Visit"),
            TraitCategory::Visitor
        );
        assert_eq!(
            registry.categorize_trait("Iterator"),
            TraitCategory::Iterator
        );
        assert_eq!(
            registry.categorize_trait("Default"),
            TraitCategory::Standard
        );
        assert_eq!(
            registry.categorize_trait("UnknownTrait"),
            TraitCategory::Custom
        );
    }

    #[test]
    fn test_trait_category_weights() {
        assert!((TraitCategory::Visitor.default_weight() - 0.1).abs() < f64::EPSILON);
        assert!((TraitCategory::Serialization.default_weight() - 0.1).abs() < f64::EPSILON);
        assert!((TraitCategory::Iterator.default_weight() - 0.3).abs() < f64::EPSILON);
        assert!((TraitCategory::Async.default_weight() - 0.3).abs() < f64::EPSILON);
        assert!((TraitCategory::Comparison.default_weight() - 0.2).abs() < f64::EPSILON);
        assert!((TraitCategory::Error.default_weight() - 0.3).abs() < f64::EPSILON);
        assert!((TraitCategory::Standard.default_weight() - 0.2).abs() < f64::EPSILON);
        assert!((TraitCategory::Custom.default_weight() - 0.4).abs() < f64::EPSILON);
    }

    #[test]
    fn test_method_origin_weights() {
        let mandated = MethodOrigin::TraitMandated {
            trait_name: "syn::Visit".into(),
            category: TraitCategory::Visitor,
        };
        assert!((mandated.weight() - 0.1).abs() < f64::EPSILON);
        let self_chosen = MethodOrigin::SelfChosen;
        assert!((self_chosen.weight() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_classify_method_origin() {
        let registry = KnownTraitRegistry::default();
        let trait_impls = vec![TraitImplInfo::new(
            "syn::Visit".into(),
            vec!["visit_expr".into(), "visit_stmt".into()],
            TraitCategory::Visitor,
        )];
        let origin = classify_method_origin("visit_expr", &trait_impls, &registry);
        assert!(origin.is_trait_mandated());
        let origin = classify_method_origin("process_data", &trait_impls, &registry);
        assert!(origin.is_extractable());
    }

    #[test]
    fn test_trait_method_summary() {
        let registry = KnownTraitRegistry::default();
        let trait_impls = vec![TraitImplInfo::new(
            "syn::Visit".into(),
            vec!["visit_expr".into(), "visit_stmt".into()],
            TraitCategory::Visitor,
        )];
        let method_names = vec![
            "visit_expr".into(),
            "visit_stmt".into(),
            "process_data".into(),
            "helper".into(),
        ];
        let classified = classify_all_methods(&method_names, &trait_impls, &registry);
        let summary = TraitMethodSummary::from_classifications(&classified);
        assert_eq!(summary.total_methods, 4);
        assert_eq!(summary.mandated_count, 2);
        assert_eq!(summary.extractable_count, 2);
        assert_eq!(*summary.by_trait.get("syn::Visit").unwrap(), 2);
        assert!((summary.weighted_count - 2.2).abs() < 0.01);
    }

    #[test]
    fn test_trait_method_summary_display() {
        let summary = TraitMethodSummary {
            mandated_count: 14,
            by_trait: [("syn::Visit".into(), 14)].into_iter().collect(),
            weighted_count: 15.8,
            extractable_count: 18,
            total_methods: 32,
        };
        assert_eq!(
            summary.to_string(),
            "32 (14 trait-mandated, 18 extractable)"
        );
    }
}