use std::collections::HashMap;
use syn::{visit::Visit, File, Item, ItemImpl};
/// Summary of a single `impl Trait for Type` block found in a parsed file.
#[derive(Debug, Clone)]
pub struct TraitImplInfo {
/// Identifier of the last path segment of the implemented trait.
pub trait_name: String,
/// Identifier of the last path segment of the implementing type.
pub type_name: String,
/// Number of lines inside the impl's span that are neither blank nor
/// start with `//` after trimming.
pub line_count: usize,
/// Line on which the impl's span starts (treated as 1-based by the line counter).
pub start_line: usize,
/// Line on which the impl's span ends (inclusive).
pub end_line: usize,
/// Whether `type_name` matched the name of a unit struct (`struct Foo;`)
/// found while walking the same file.
pub is_unit_struct: bool,
}
/// Metrics describing a detected registry pattern: many small
/// implementations of one dominant trait within a single file.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RegistryPattern {
/// Name of the trait with the most implementations in the file.
pub trait_name: String,
/// Number of implementations of `trait_name`.
pub impl_count: usize,
/// Mean counted line count across those implementations.
pub avg_impl_size: f64,
/// Population standard deviation of the implementations' line counts.
pub impl_size_stddev: f64,
/// Total number of lines in the analysed file content.
pub total_lines: usize,
/// Fraction of the implementations whose type was flagged as a unit struct.
pub unit_struct_ratio: f64,
/// Heuristic flag: the file text contains both `const` and `&[`.
pub has_static_registry: bool,
/// Counted implementation lines divided by `total_lines`.
pub trait_impl_coverage: f64,
}
/// Thresholds that decide whether a file is reported as a registry pattern.
pub struct RegistryPatternDetector {
/// Minimum number of implementations of the dominant trait; fewer means no match.
pub min_impl_count: usize,
/// Exclusive upper bound on the average implementation size in lines;
/// an average equal to or above this value means no match.
pub max_avg_impl_size: usize,
/// Minimum fraction of the file's lines that must belong to the dominant
/// trait's implementations.
pub min_coverage: f64,
}
impl Default for RegistryPatternDetector {
fn default() -> Self {
Self {
min_impl_count: 20,
max_avg_impl_size: 15,
min_coverage: 0.80,
}
}
}
impl RegistryPatternDetector {
    /// Creates a detector configured with the default thresholds.
    pub fn new() -> Self {
        Self::default()
    }

    /// Analyses a parsed file and reports the registry pattern formed by its
    /// most frequently implemented trait, if every threshold is met.
    ///
    /// `file_content` must be the source text `file` was parsed from; it is
    /// used for line counting and for the static-registry heuristic.
    ///
    /// Returns `None` when:
    /// - the content has no lines or the file has no usable trait impls,
    /// - the dominant trait has fewer than `min_impl_count` impls,
    /// - the average impl size is `>= max_avg_impl_size`, or
    /// - the impls cover less than `min_coverage` of the file's lines.
    ///
    /// When several traits share the highest impl count, the one with the
    /// lexicographically smallest name is chosen so results are reproducible.
    pub fn detect(&self, file: &File, file_content: &str) -> Option<RegistryPattern> {
        let total_lines = file_content.lines().count();
        // Without any lines the coverage ratio below would be 0/0 = NaN,
        // which silently passes a `<` comparison.
        if total_lines == 0 {
            return None;
        }

        let trait_impls = extract_trait_impls(file, file_content);

        let mut impls_by_trait: HashMap<&str, Vec<&TraitImplInfo>> = HashMap::new();
        for impl_info in &trait_impls {
            impls_by_trait
                .entry(impl_info.trait_name.as_str())
                .or_default()
                .push(impl_info);
        }

        // `HashMap` iteration order is unspecified, so break count ties on the
        // trait name instead of relying on whichever entry is visited last.
        let (dominant_trait, dominant_impls) = impls_by_trait
            .into_iter()
            .max_by(|a, b| a.1.len().cmp(&b.1.len()).then_with(|| b.0.cmp(a.0)))?;

        let impl_count = dominant_impls.len();
        if impl_count < self.min_impl_count {
            return None;
        }

        let total_impl_lines: usize = dominant_impls.iter().map(|i| i.line_count).sum();
        let avg_impl_size = total_impl_lines as f64 / impl_count as f64;
        if avg_impl_size >= self.max_avg_impl_size as f64 {
            return None;
        }

        // Population variance (divide by N, not N - 1).
        let variance: f64 = dominant_impls
            .iter()
            .map(|i| {
                let diff = i.line_count as f64 - avg_impl_size;
                diff * diff
            })
            .sum::<f64>()
            / impl_count as f64;
        let impl_size_stddev = variance.sqrt();

        let unit_struct_count = dominant_impls.iter().filter(|i| i.is_unit_struct).count();
        let unit_struct_ratio = unit_struct_count as f64 / impl_count as f64;

        let trait_impl_coverage = total_impl_lines as f64 / total_lines as f64;
        if trait_impl_coverage < self.min_coverage {
            return None;
        }

        // Purely textual heuristic: any `const` plus any slice-reference
        // literal anywhere in the file counts as a static registry.
        let has_static_registry = file_content.contains("const") && file_content.contains("&[");

        Some(RegistryPattern {
            trait_name: dominant_trait.to_owned(),
            impl_count,
            avg_impl_size,
            impl_size_stddev,
            total_lines,
            unit_struct_ratio,
            has_static_registry,
            trait_impl_coverage,
        })
    }

    /// Scores how strongly `pattern` resembles a registry, in `[0.0, 1.0]`.
    ///
    /// The score is the sum of independent contributions:
    /// - impl count: `impl_count / 100`, capped at 0.3
    /// - average size: 0.3 below 10 lines, 0.2 below 15 lines
    /// - coverage: 0.2 above 0.9, 0.1 above 0.8
    /// - unit-struct ratio: 0.15 above 0.8, 0.1 above 0.5
    /// - static registry present: 0.05
    pub fn confidence(&self, pattern: &RegistryPattern) -> f64 {
        let mut confidence = 0.0;

        confidence += (pattern.impl_count as f64 / 100.0).min(0.3);

        if pattern.avg_impl_size < 10.0 {
            confidence += 0.3;
        } else if pattern.avg_impl_size < 15.0 {
            confidence += 0.2;
        }

        if pattern.trait_impl_coverage > 0.9 {
            confidence += 0.2;
        } else if pattern.trait_impl_coverage > 0.8 {
            confidence += 0.1;
        }

        if pattern.unit_struct_ratio > 0.8 {
            confidence += 0.15;
        } else if pattern.unit_struct_ratio > 0.5 {
            confidence += 0.1;
        }

        if pattern.has_static_registry {
            confidence += 0.05;
        }

        confidence.min(1.0)
    }
}
/// Walks `file` and returns one [`TraitImplInfo`] per trait implementation
/// whose self type is a plain path type.
///
/// `file_content` is the source text used for counting lines inside each
/// impl's span.
///
/// `is_unit_struct` is resolved only after the whole file has been walked,
/// so it does not depend on whether a unit struct is declared before or
/// after its impl.
fn extract_trait_impls(file: &File, file_content: &str) -> Vec<TraitImplInfo> {
    let mut visitor = TraitImplVisitor {
        impls: Vec::new(),
        unit_structs: std::collections::HashSet::new(),
        file_content,
    };
    visitor.visit_file(file);

    let TraitImplVisitor {
        mut impls,
        unit_structs,
        ..
    } = visitor;

    // The visitor only knows the unit structs it has already seen when it
    // reaches an impl; re-evaluate the flag against the complete set.
    for info in &mut impls {
        info.is_unit_struct = unit_structs.contains(&info.type_name);
    }

    impls
}
/// Syntax-tree visitor that collects trait implementations and the names of
/// unit structs from a parsed file.
struct TraitImplVisitor<'a> {
/// Trait implementations collected so far, in visit order.
impls: Vec<TraitImplInfo>,
/// Names of the unit structs (`struct Foo;`) visited so far.
unit_structs: std::collections::HashSet<String>,
/// Source text handed to `extract_impl_info` for line counting.
file_content: &'a str,
}
impl<'a, 'ast> Visit<'ast> for TraitImplVisitor<'a> {
    /// Records unit-struct names and trait implementations, then continues
    /// the default traversal so nested items are visited as well.
    fn visit_item(&mut self, item: &'ast Item) {
        if let Item::Struct(decl) = item {
            if let syn::Fields::Unit = decl.fields {
                self.unit_structs.insert(decl.ident.to_string());
            }
        } else if let Item::Impl(block) = item {
            if let Some(info) = extract_impl_info(block, self.file_content) {
                // Only unit structs visited before this impl are known here.
                let is_unit_struct = self.unit_structs.contains(&info.type_name);
                self.impls.push(TraitImplInfo {
                    is_unit_struct,
                    ..info
                });
            }
        }

        syn::visit::visit_item(self, item);
    }
}
/// Builds a [`TraitImplInfo`] for `item_impl`, or returns `None` when the
/// block is an inherent impl (no trait) or its self type is not a path type.
///
/// Trait and type names are the identifiers of the last path segment. The
/// returned `is_unit_struct` is always `false`; callers fill it in.
fn extract_impl_info(item_impl: &ItemImpl, file_content: &str) -> Option<TraitImplInfo> {
    use syn::spanned::Spanned;

    let trait_path = match &item_impl.trait_ {
        Some((_, path, _)) => path,
        None => return None,
    };
    let trait_name = trait_path
        .segments
        .last()
        .map(|segment| segment.ident.to_string())?;

    let type_name = if let syn::Type::Path(self_path) = &*item_impl.self_ty {
        self_path.path.segments.last()?.ident.to_string()
    } else {
        return None;
    };

    let impl_span = item_impl.span();
    let (start_line, end_line) = (impl_span.start().line, impl_span.end().line);

    Some(TraitImplInfo {
        line_count: count_lines_in_span(file_content, start_line, end_line),
        trait_name,
        type_name,
        start_line,
        end_line,
        is_unit_struct: false,
    })
}
/// Counts the lines of `content` between `start_line` and `end_line`
/// (1-based, inclusive) that are neither blank nor start with `//` after
/// trimming.
///
/// Lines beyond the end of `content` are ignored. An inverted range
/// (`end_line < start_line`) contains no lines and yields `0`.
fn count_lines_in_span(content: &str, start_line: usize, end_line: usize) -> usize {
    if end_line < start_line {
        return 0;
    }

    // Convert the 1-based start line into a 0-based number of lines to skip.
    let lines_to_skip = start_line.saturating_sub(1);
    // Inclusive range length; saturate so a huge range cannot overflow.
    let span_len = (end_line - start_line).saturating_add(1);

    content
        .lines()
        .skip(lines_to_skip)
        .take(span_len)
        .filter(|line| {
            let trimmed = line.trim();
            !trimmed.is_empty() && !trimmed.starts_with("//")
        })
        .count()
}
/// Scales `base_score` down for a file recognised as a registry pattern.
///
/// The multiplier depends on the pattern's average implementation size:
/// `0.2` below 10 lines, `0.3` below 15 lines, and `0.5` otherwise.
pub fn adjust_registry_score(base_score: f64, pattern: &RegistryPattern) -> f64 {
    let multiplier = match pattern.avg_impl_size {
        size if size < 10.0 => 0.2,
        size if size < 15.0 => 0.3,
        _ => 0.5,
    };

    base_score * multiplier
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` into a syntax tree, panicking on invalid input.
    fn parse_rust_code(code: &str) -> File {
        syn::parse_str(code).expect("Failed to parse Rust code")
    }

    /// Three one-line impls of the same trait are reported as a registry
    /// once the thresholds are lowered to fit the tiny fixture.
    #[test]
    fn test_detect_registry_pattern_basic() {
        let code = r#"
struct Flag1;
struct Flag2;
struct Flag3;
trait Flag {
fn name(&self) -> &str;
}
impl Flag for Flag1 { fn name(&self) -> &str { "flag1" } }
impl Flag for Flag2 { fn name(&self) -> &str { "flag2" } }
impl Flag for Flag3 { fn name(&self) -> &str { "flag3" } }
"#;
        let file = parse_rust_code(code);
        let detector = RegistryPatternDetector {
            min_impl_count: 3,
            max_avg_impl_size: 15,
            min_coverage: 0.1,
        };
        let pattern = detector.detect(&file, code);
        assert!(pattern.is_some());
        let pattern = pattern.unwrap();
        assert_eq!(pattern.trait_name, "Flag");
        assert_eq!(pattern.impl_count, 3);
    }

    /// Small average impl size yields the strongest (0.2x) score reduction.
    #[test]
    fn test_registry_score_reduction() {
        let pattern = RegistryPattern {
            trait_name: "Flag".into(),
            impl_count: 150,
            avg_impl_size: 8.0,
            total_lines: 7775,
            unit_struct_ratio: 0.95,
            has_static_registry: true,
            trait_impl_coverage: 0.90,
            impl_size_stddev: 2.5,
        };
        let base_score = 1000.0;
        let adjusted = adjust_registry_score(base_score, &pattern);
        assert!((adjusted - 200.0).abs() < 1.0);
    }

    /// An impl larger than `max_avg_impl_size` must be rejected because of
    /// its size, not because of the count or coverage thresholds.
    #[test]
    fn test_not_registry_large_impls() {
        let code = r#"
trait Processor {
fn process(&self, data: &str) -> String;
fn validate(&self, data: &str) -> bool;
fn transform(&self, data: &str) -> String;
}
impl Processor for TypeA {
fn process(&self, data: &str) -> String {
// Many lines of complex logic
let mut result = String::new();
for line in data.lines() {
result.push_str(&line.to_uppercase());
result.push('\n');
}
result
}
fn validate(&self, data: &str) -> bool { true }
fn transform(&self, data: &str) -> String { data.to_string() }
}
"#;
        let file = parse_rust_code(code);

        // The impl has 12 counted lines. Count and coverage thresholds are
        // disabled so the size limit is the only thing that can reject it.
        let strict = RegistryPatternDetector {
            min_impl_count: 1,
            max_avg_impl_size: 10,
            min_coverage: 0.0,
        };
        let pattern = strict.detect(&file, code);
        assert!(
            pattern.is_none(),
            "Large implementations should not be registry"
        );

        // Control: with a generous size limit the same input is detected,
        // which shows the rejection above really comes from the size check.
        let lenient = RegistryPatternDetector {
            min_impl_count: 1,
            max_avg_impl_size: 100,
            min_coverage: 0.0,
        };
        assert!(
            lenient.detect(&file, code).is_some(),
            "Size limit should be the only reason for rejection"
        );
    }

    /// Confidence is high for a textbook registry and low for a borderline one.
    #[test]
    fn test_confidence_calculation() {
        let detector = RegistryPatternDetector::default();
        let high_confidence_pattern = RegistryPattern {
            trait_name: "Flag".into(),
            impl_count: 100,
            avg_impl_size: 8.0,
            total_lines: 1000,
            unit_struct_ratio: 0.9,
            has_static_registry: true,
            trait_impl_coverage: 0.95,
            impl_size_stddev: 2.0,
        };
        let confidence = detector.confidence(&high_confidence_pattern);
        assert!(
            confidence > 0.8,
            "High confidence pattern should score > 0.8"
        );

        let low_confidence_pattern = RegistryPattern {
            trait_name: "Trait".into(),
            impl_count: 20,
            avg_impl_size: 14.0,
            total_lines: 500,
            unit_struct_ratio: 0.3,
            has_static_registry: false,
            trait_impl_coverage: 0.80,
            impl_size_stddev: 5.0,
        };
        let confidence = detector.confidence(&low_confidence_pattern);
        assert!(
            confidence < 0.6,
            "Low confidence pattern should score < 0.6"
        );
    }

    /// Impls for unit structs are flagged; impls for structs with fields are not.
    #[test]
    fn test_unit_struct_detection() {
        let code = r#"
struct UnitStruct;
struct RegularStruct { field: i32 }
trait Trait {
fn method(&self);
}
impl Trait for UnitStruct {
fn method(&self) {}
}
impl Trait for RegularStruct {
fn method(&self) {}
}
"#;
        let file = parse_rust_code(code);
        let impls = extract_trait_impls(&file, code);
        let unit_impl = impls.iter().find(|i| i.type_name == "UnitStruct");
        let regular_impl = impls.iter().find(|i| i.type_name == "RegularStruct");
        assert!(unit_impl.is_some());
        assert!(unit_impl.unwrap().is_unit_struct);
        assert!(regular_impl.is_some());
        assert!(!regular_impl.unwrap().is_unit_struct);
    }
}