use std::collections::HashMap;
use syn::{ExprField, Item, Member};
/// Heuristic detector that scores how likely a struct field access is a
/// "state" field, combining name keywords, enum type information, naming
/// patterns and recorded usage frequency.
#[derive(Debug, Clone)]
pub struct StateFieldDetector {
    /// Keyword and pattern lists the field name is matched against.
    keywords: StateKeywordDict,
    /// Enum descriptions filled in by `build_type_database`, keyed by the
    /// enum's identifier.
    type_cache: HashMap<String, TypeInfo>,
    /// Per-field usage counters that feed the frequency score.
    usage_tracker: UsageTracker,
    /// Feature toggles, thresholds and user-supplied keywords/patterns.
    config: StateDetectionConfig,
}
/// Word lists used to recognise state-like field names.
#[derive(Debug, Clone)]
pub struct StateKeywordDict {
    /// Keywords searched for as substrings of the lower-cased field name.
    pub field_keywords: Vec<String>,
    /// Keywords intended for path matching.
    /// NOTE(review): not read by any code visible in this file — confirm usage.
    pub path_keywords: Vec<String>,
    /// Prefixes the lower-cased field name may start with.
    pub prefix_patterns: Vec<String>,
    /// Suffixes the lower-cased field name may end with.
    pub suffix_patterns: Vec<String>,
    /// Whole field names that count as a compound-pattern match.
    pub compound_patterns: Vec<String>,
}

impl Default for StateKeywordDict {
    /// Returns the built-in dictionary; every entry is lower-case.
    fn default() -> Self {
        // Converts a list of literals into the owned form the struct stores.
        fn owned(words: &[&str]) -> Vec<String> {
            words.iter().map(|word| String::from(*word)).collect()
        }

        let field_keywords = owned(&[
            "state",
            "mode",
            "status",
            "phase",
            "stage",
            "desired",
            "current",
            "target",
            "actual",
            "fsm",
            "transition",
            "automaton",
            "machine",
            "lifecycle",
            "step",
            "iteration",
            "round",
            "flow",
            "control",
            "sequence",
            "type",
            "kind",
            "variant",
            "form",
            "connection",
            "protocol",
            "handshake",
            "request",
            "response",
            "reply",
            "ctx",
            "context",
            "env",
            "environment",
            "operation",
            "action",
        ]);

        let path_keywords = owned(&[
            "state",
            "mode",
            "status",
            "phase",
            "fsm",
            "transition",
            "stage",
            "step",
            "ctx",
            "context",
            "kind",
            "type",
        ]);

        let prefix_patterns = owned(&[
            "current_",
            "next_",
            "prev_",
            "previous_",
            "target_",
            "desired_",
            "actual_",
            "expected_",
            "old_",
            "new_",
            "initial_",
            "final_",
        ]);

        let suffix_patterns = owned(&[
            "_state", "_mode", "_status", "_phase", "_stage", "_step", "_type", "_kind",
            "_variant", "_flag", "_control",
        ]);

        let compound_patterns = owned(&[
            "flow_control",
            "state_machine",
            "fsm_state",
            "request_type",
            "response_kind",
            "connection_state",
            "protocol_phase",
            "processing_stage",
            "lifecycle_step",
            "current_operation",
            "next_operation",
            "current_action",
            "next_action",
        ]);

        Self {
            field_keywords,
            path_keywords,
            prefix_patterns,
            suffix_patterns,
            compound_patterns,
        }
    }
}
/// Description of a type recorded in the detector's type cache.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    /// Whether the type is an enum.
    pub is_enum: bool,
    /// Number of variants the enum declares.
    pub variant_count: usize,
    /// Identifiers of the enum's variants, in declaration order.
    pub variants: Vec<String>,
    /// NOTE(review): presumably marks a wrapper such as `Option<T>`; the code
    /// visible in this file only ever stores `false` — confirm intent.
    pub is_wrapped: bool,
    /// NOTE(review): presumably the wrapped type's name when `is_wrapped` is
    /// set; the code visible in this file only ever stores `None`.
    pub inner_type: Option<String>,
}
/// Counts how often each field name is seen in `match` expressions and in
/// comparisons, and turns those counts into a frequency score.
#[derive(Debug, Clone)]
pub struct UsageTracker {
    /// Number of recorded `match` usages per field name.
    match_counts: HashMap<String, usize>,
    /// Number of recorded comparison usages per field name.
    comparison_counts: HashMap<String, usize>,
    /// Total recorded usages per field name (incremented by both recorders).
    occurrence_counts: HashMap<String, usize>,
}

impl UsageTracker {
    /// Creates a tracker with no recorded usages.
    pub fn new() -> Self {
        let match_counts = HashMap::new();
        let comparison_counts = HashMap::new();
        let occurrence_counts = HashMap::new();
        Self {
            match_counts,
            comparison_counts,
            occurrence_counts,
        }
    }

    /// Records one `match` usage of `field`, also bumping its total
    /// occurrence count.
    pub fn record_match(&mut self, field: &str) {
        let key = field.to_owned();
        *self.match_counts.entry(key.clone()).or_default() += 1;
        *self.occurrence_counts.entry(key).or_default() += 1;
    }

    /// Records one comparison usage of `field`, also bumping its total
    /// occurrence count.
    pub fn record_comparison(&mut self, field: &str) {
        let key = field.to_owned();
        *self.comparison_counts.entry(key.clone()).or_default() += 1;
        *self.occurrence_counts.entry(key).or_default() += 1;
    }

    /// Maps the combined match + comparison count of `field` to a score:
    /// `0.0` for no usages, `0.1` for 1–2, `0.3` for 3–5 and `0.4` above that.
    pub fn get_frequency_score(&self, field: &str) -> f64 {
        let lookup = |counts: &HashMap<String, usize>| counts.get(field).map_or(0, |c| *c);
        let uses = lookup(&self.match_counts) + lookup(&self.comparison_counts);

        if uses == 0 {
            0.0
        } else if uses <= 2 {
            0.1
        } else if uses <= 5 {
            0.3
        } else {
            0.4
        }
    }
}

impl Default for UsageTracker {
    /// Equivalent to [`UsageTracker::new`].
    fn default() -> Self {
        UsageTracker::new()
    }
}
/// Result of scoring a single field access.
#[derive(Debug, Clone)]
pub struct StateFieldDetection {
    /// Identifier of the field that was scored.
    pub field_name: String,
    /// Sum of the four component scores in `breakdown`.
    pub confidence: f64,
    /// Bucket derived from `confidence`.
    pub classification: ConfidenceClass,
    /// Individual component scores and a textual explanation.
    pub breakdown: ConfidenceBreakdown,
}
/// Coarse bucket for a detection's confidence value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfidenceClass {
    /// Assigned by the detector when confidence is at least 0.7.
    High,
    /// Assigned by the detector when confidence is at least 0.4 but below 0.7.
    Medium,
    /// Assigned by the detector for any lower confidence.
    Low,
}
/// Component scores that add up to a detection's confidence.
#[derive(Debug, Clone)]
pub struct ConfidenceBreakdown {
    /// Contribution from compound-pattern or keyword matches on the name.
    pub keyword_score: f64,
    /// Contribution from enum type information.
    pub type_score: f64,
    /// Contribution from prefix/suffix naming patterns.
    pub pattern_score: f64,
    /// Contribution from recorded usage frequency.
    pub frequency_score: f64,
    /// Human-readable list of the indicators that fired.
    pub explanation: String,
}
/// User-tunable settings for [`StateFieldDetector`]. Every field has a serde
/// default, so any subset of keys may be present when deserializing.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StateDetectionConfig {
    /// Enables the enum-type component of the score.
    #[serde(default = "default_use_type_analysis")]
    pub use_type_analysis: bool,

    /// Enables the usage-frequency component of the score.
    #[serde(default = "default_use_frequency_analysis")]
    pub use_frequency_analysis: bool,

    /// Enables the prefix/suffix naming-pattern component of the score.
    #[serde(default = "default_use_pattern_recognition")]
    pub use_pattern_recognition: bool,

    /// Minimum number of enum variants required before the type component
    /// contributes.
    #[serde(default = "default_min_enum_variants")]
    pub min_enum_variants: usize,

    /// Extra keywords appended to the built-in field keyword list.
    #[serde(default)]
    pub custom_keywords: Vec<String>,

    /// Extra entries appended to the built-in compound pattern list.
    #[serde(default)]
    pub custom_patterns: Vec<String>,
}
/// Serde fallback for `StateDetectionConfig::use_type_analysis`: type
/// analysis is on when the key is absent.
fn default_use_type_analysis() -> bool {
    true
}
/// Serde fallback for `StateDetectionConfig::use_frequency_analysis`:
/// frequency analysis is on when the key is absent.
fn default_use_frequency_analysis() -> bool {
    true
}
/// Serde fallback for `StateDetectionConfig::use_pattern_recognition`:
/// pattern recognition is on when the key is absent.
fn default_use_pattern_recognition() -> bool {
    true
}
/// Serde fallback for `StateDetectionConfig::min_enum_variants`: three
/// variants when the key is absent.
fn default_min_enum_variants() -> usize {
    3
}
impl Default for StateDetectionConfig {
fn default() -> Self {
Self {
use_type_analysis: true,
use_frequency_analysis: true,
use_pattern_recognition: true,
min_enum_variants: 3,
custom_keywords: Vec::new(),
custom_patterns: Vec::new(),
}
}
}
impl StateFieldDetector {
    /// Builds a detector from `config`.
    ///
    /// The built-in dictionary is extended with `config.custom_keywords`
    /// (as field keywords) and `config.custom_patterns` (as compound
    /// patterns). The lists consulted during detection are lower-cased here,
    /// once, so the per-field matching helpers can compare against them
    /// without allocating a lower-cased copy of every entry on every call.
    pub fn new(config: StateDetectionConfig) -> Self {
        // Lower-cases every entry of one dictionary list in place.
        fn lowercase_all(entries: &mut [String]) {
            for entry in entries.iter_mut() {
                *entry = entry.to_lowercase();
            }
        }

        let mut keywords = StateKeywordDict::default();
        keywords
            .field_keywords
            .extend(config.custom_keywords.iter().cloned());
        keywords
            .compound_patterns
            .extend(config.custom_patterns.iter().cloned());

        // Invariant relied on by `is_compound_pattern`, `contains_field_keyword`
        // and `semantic_pattern_score`: these four lists hold lower-case text.
        lowercase_all(&mut keywords.field_keywords);
        lowercase_all(&mut keywords.prefix_patterns);
        lowercase_all(&mut keywords.suffix_patterns);
        lowercase_all(&mut keywords.compound_patterns);

        Self {
            keywords,
            type_cache: HashMap::new(),
            usage_tracker: UsageTracker::new(),
            config,
        }
    }

    /// Scores how likely `field_expr` accesses a state field.
    ///
    /// Components, summed in this order into `confidence`:
    /// * keyword: `0.5` when the lower-cased name equals a compound pattern,
    ///   otherwise `0.3` when it contains a field keyword;
    /// * type (if `use_type_analysis`): `0.4` when the type cache holds an
    ///   enum under the field's identifier with at least `min_enum_variants`
    ///   variants;
    /// * pattern (if `use_pattern_recognition`): `0.25` for a known prefix
    ///   plus `0.25` for a known suffix;
    /// * frequency (if `use_frequency_analysis`): the usage tracker's score.
    ///
    /// The sum is not clamped, so it can exceed `1.0` (the component maxima
    /// add up to `1.8`). It is classified as `High` at `>= 0.7`, `Medium` at
    /// `>= 0.4` and `Low` otherwise. Tuple-style members (`x.0`) short-circuit
    /// to a zero-confidence result named `"unnamed"`.
    pub fn detect_state_field(&self, field_expr: &ExprField) -> StateFieldDetection {
        let field_name = match &field_expr.member {
            Member::Named(ident) => ident.to_string(),
            Member::Unnamed(_) => return self.low_confidence_result("unnamed"),
        };
        // Lower-cased once and shared by the keyword and pattern checks.
        let normalized = field_name.to_lowercase();

        let mut breakdown = ConfidenceBreakdown {
            keyword_score: 0.0,
            type_score: 0.0,
            pattern_score: 0.0,
            frequency_score: 0.0,
            explanation: String::new(),
        };

        if self.is_compound_pattern(&normalized) {
            breakdown.keyword_score = 0.5;
            breakdown.explanation.push_str("compound pattern match; ");
        } else if self.contains_field_keyword(&normalized) {
            breakdown.keyword_score = 0.3;
            breakdown.explanation.push_str("keyword match; ");
        }

        if self.config.use_type_analysis {
            // Borrow from the cache instead of cloning the whole `TypeInfo`
            // (including its variant list) just to read two fields.
            // NOTE(review): the cache is keyed by enum identifier (see
            // `build_type_database`) but looked up by field identifier, so
            // this only fires when the two are spelled identically.
            if let Some(type_info) = self.type_cache.get(&field_name) {
                if type_info.is_enum && type_info.variant_count >= self.config.min_enum_variants {
                    breakdown.type_score = 0.4;
                    breakdown
                        .explanation
                        .push_str(&format!("enum with {} variants; ", type_info.variant_count));
                }
            }
        }

        if self.config.use_pattern_recognition {
            let pattern_score = self.semantic_pattern_score(&normalized);
            breakdown.pattern_score = pattern_score;
            if pattern_score > 0.0 {
                breakdown.explanation.push_str("semantic pattern; ");
            }
        }

        if self.config.use_frequency_analysis {
            // The tracker is keyed by the identifier as written, not the
            // lower-cased form.
            let freq_score = self.usage_tracker.get_frequency_score(&field_name);
            breakdown.frequency_score = freq_score;
            if freq_score > 0.0 {
                breakdown.explanation.push_str(&format!(
                    "high usage frequency (score: {:.2}); ",
                    freq_score
                ));
            }
        }

        let confidence = breakdown.keyword_score
            + breakdown.type_score
            + breakdown.pattern_score
            + breakdown.frequency_score;

        let classification = if confidence >= 0.7 {
            ConfidenceClass::High
        } else if confidence >= 0.4 {
            ConfidenceClass::Medium
        } else {
            ConfidenceClass::Low
        };

        StateFieldDetection {
            field_name,
            confidence,
            classification,
            breakdown,
        }
    }

    /// Returns `true` when `field_name`, compared case-insensitively, equals
    /// a compound pattern or contains any field keyword as a substring.
    fn matches_keyword(&self, field_name: &str) -> bool {
        let normalized = field_name.to_lowercase();
        self.is_compound_pattern(&normalized) || self.contains_field_keyword(&normalized)
    }

    /// Exact match of an already lower-cased name against the compound
    /// patterns (which `new` stores lower-cased).
    fn is_compound_pattern(&self, normalized: &str) -> bool {
        self.keywords
            .compound_patterns
            .iter()
            .any(|pattern| pattern.as_str() == normalized)
    }

    /// Substring match of an already lower-cased name against the field
    /// keywords (which `new` stores lower-cased).
    fn contains_field_keyword(&self, normalized: &str) -> bool {
        self.keywords
            .field_keywords
            .iter()
            .any(|keyword| normalized.contains(keyword.as_str()))
    }

    /// Scores prefix/suffix naming patterns for `field_name`, compared
    /// case-insensitively: `0.25` for a known prefix plus `0.25` for a known
    /// suffix, so the result is `0.0`, `0.25` or `0.5`.
    fn analyze_semantic_patterns(&self, field_name: &str) -> f64 {
        self.semantic_pattern_score(&field_name.to_lowercase())
    }

    /// Prefix/suffix scoring for a name that is already lower-cased; the
    /// pattern lists are stored lower-cased by `new`.
    fn semantic_pattern_score(&self, normalized: &str) -> f64 {
        let has_prefix = self
            .keywords
            .prefix_patterns
            .iter()
            .any(|prefix| normalized.starts_with(prefix.as_str()));
        let has_suffix = self
            .keywords
            .suffix_patterns
            .iter()
            .any(|suffix| normalized.ends_with(suffix.as_str()));

        let mut score: f64 = 0.0;
        if has_prefix {
            score += 0.25;
        }
        if has_suffix {
            score += 0.25;
        }
        score.min(0.5)
    }

    /// Looks up cached type information under the field's own identifier.
    ///
    /// Returns `None` for tuple-style members and for names absent from the
    /// cache. NOTE(review): no type resolution happens here — the cache is
    /// keyed by enum identifier, so a hit requires the field identifier to be
    /// spelled exactly like an enum's name; confirm whether that is intended.
    fn analyze_field_type(&self, field_expr: &ExprField) -> Option<TypeInfo> {
        let field_name = match &field_expr.member {
            Member::Named(ident) => ident.to_string(),
            Member::Unnamed(_) => return None,
        };
        self.type_cache.get(&field_name).cloned()
    }

    /// Records every enum found directly in `items` in the type cache, keyed
    /// by the enum's identifier. Non-enum items are ignored; an enum whose
    /// identifier is already cached replaces the earlier entry.
    pub fn build_type_database(&mut self, items: &[Item]) {
        let enums = items.iter().filter_map(|item| match item {
            Item::Enum(enum_item) => Some(enum_item),
            _ => None,
        });

        for enum_item in enums {
            let variant_count = enum_item.variants.len();
            let variants: Vec<String> = enum_item
                .variants
                .iter()
                .map(|variant| variant.ident.to_string())
                .collect();

            self.type_cache.insert(
                enum_item.ident.to_string(),
                TypeInfo {
                    is_enum: true,
                    variant_count,
                    variants,
                    is_wrapped: false,
                    inner_type: None,
                },
            );
        }
    }

    /// Builds a zero-confidence, `Low` detection for `field_name` whose
    /// explanation is `"no indicators"`.
    fn low_confidence_result(&self, field_name: &str) -> StateFieldDetection {
        StateFieldDetection {
            field_name: field_name.to_string(),
            confidence: 0.0,
            classification: ConfidenceClass::Low,
            breakdown: ConfidenceBreakdown {
                keyword_score: 0.0,
                type_score: 0.0,
                pattern_score: 0.0,
                frequency_score: 0.0,
                explanation: "no indicators".to_string(),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// The basic keywords are recognised.
    #[test]
    fn test_keyword_detection_original() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        assert!(detector.matches_keyword("state"));
        assert!(detector.matches_keyword("mode"));
        assert!(detector.matches_keyword("status"));
    }

    /// Further built-in keywords are recognised.
    #[test]
    fn test_keyword_detection_new() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        assert!(detector.matches_keyword("fsm"));
        assert!(detector.matches_keyword("transition"));
        assert!(detector.matches_keyword("lifecycle"));
        assert!(detector.matches_keyword("ctx"));
    }

    /// Names starting with a known prefix get a non-zero pattern score.
    #[test]
    fn test_prefix_pattern() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        let score = detector.analyze_semantic_patterns("current_action");
        assert!(score > 0.0);
        let score = detector.analyze_semantic_patterns("next_step");
        assert!(score > 0.0);
    }

    /// Names ending with a known suffix get a non-zero pattern score.
    #[test]
    fn test_suffix_pattern() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        let score = detector.analyze_semantic_patterns("connection_state");
        assert!(score > 0.0);
        let score = detector.analyze_semantic_patterns("request_type");
        assert!(score > 0.0);
    }

    /// Compound patterns count as keyword matches.
    #[test]
    fn test_compound_pattern() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        assert!(detector.matches_keyword("flow_control"));
        assert!(detector.matches_keyword("state_machine"));
    }

    /// A compound-pattern field reaches at least the keyword-only score.
    #[test]
    fn test_confidence_aggregation() {
        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        let field: ExprField = parse_quote! { self.fsm_state };
        let detection = detector.detect_state_field(&field);
        assert!(detection.confidence >= 0.3);
    }

    /// Enums passed to `build_type_database` land in the type cache under
    /// their own identifier.
    #[test]
    fn test_enum_type_detection() {
        let mut detector = StateFieldDetector::new(StateDetectionConfig::default());
        let items: Vec<Item> = vec![parse_quote! {
            enum ConnectionState {
                Idle,
                Connecting,
                Connected,
                Disconnected,
            }
        }];
        detector.build_type_database(&items);
        let type_info = detector.type_cache.get("ConnectionState").unwrap();
        assert!(type_info.is_enum);
        assert_eq!(type_info.variant_count, 4);
    }

    /// Three recorded usages fall into the 3–5 bucket.
    #[test]
    fn test_usage_frequency() {
        let mut tracker = UsageTracker::new();
        tracker.record_match("flow_control");
        tracker.record_match("flow_control");
        tracker.record_match("flow_control");
        let score = tracker.get_frequency_score("flow_control");
        assert!(score >= 0.3);
    }

    /// Keywords supplied through the config are matched like built-in ones.
    #[test]
    fn test_custom_keywords() {
        let config = StateDetectionConfig {
            custom_keywords: vec!["workflow".to_string(), "scenario".to_string()],
            ..Default::default()
        };
        let detector = StateFieldDetector::new(config);
        assert!(detector.matches_keyword("workflow"));
        assert!(detector.matches_keyword("scenario"));
    }

    /// Compares a keyword-only detector with the fully enabled one on a set
    /// of non-standard field names and checks the reduction in misses.
    #[test]
    fn test_false_negative_reduction() {
        let baseline_config = StateDetectionConfig {
            use_type_analysis: false,
            use_frequency_analysis: false,
            use_pattern_recognition: false,
            min_enum_variants: 3,
            custom_keywords: vec![],
            custom_patterns: vec![],
        };
        let baseline_detector = StateFieldDetector::new(baseline_config);
        let enhanced_detector = StateFieldDetector::new(StateDetectionConfig::default());

        let non_standard_state_fields: Vec<ExprField> = vec![
            parse_quote! { self.current_action },
            parse_quote! { self.next_step },
            parse_quote! { self.active_process },
            parse_quote! { self.connection_type },
            parse_quote! { self.operation_kind },
            parse_quote! { self.request_stage },
            parse_quote! { self.fsm_state },
            parse_quote! { self.flow_control },
            parse_quote! { self.lifecycle_phase },
            parse_quote! { self.ctx },
            parse_quote! { self.context },
            parse_quote! { self.transition },
        ];

        let mut baseline_detected = 0;
        let mut enhanced_detected = 0;
        for field in &non_standard_state_fields {
            let baseline_result = baseline_detector.detect_state_field(field);
            let enhanced_result = enhanced_detector.detect_state_field(field);
            if baseline_result.classification != ConfidenceClass::Low {
                baseline_detected += 1;
            }
            if enhanced_result.classification != ConfidenceClass::Low {
                enhanced_detected += 1;
            }
        }

        let total = non_standard_state_fields.len();
        let baseline_false_negatives = total - baseline_detected;
        let enhanced_false_negatives = total - enhanced_detected;
        // Saturate so that a regression (enhanced missing more than baseline)
        // reports 0% and fails the assertion below with its message, instead
        // of panicking on unsigned underflow.
        let false_negatives_removed =
            baseline_false_negatives.saturating_sub(enhanced_false_negatives);
        let reduction_percentage = if baseline_false_negatives > 0 {
            (false_negatives_removed as f64 / baseline_false_negatives as f64) * 100.0
        } else {
            0.0
        };

        println!("False negative validation results:");
        println!(
            "  Baseline detected: {}/{} ({:.1}%)",
            baseline_detected,
            total,
            (baseline_detected as f64 / total as f64) * 100.0
        );
        println!(
            "  Enhanced detected: {}/{} ({:.1}%)",
            enhanced_detected,
            total,
            (enhanced_detected as f64 / total as f64) * 100.0
        );
        println!("  Baseline false negatives: {}", baseline_false_negatives);
        println!("  Enhanced false negatives: {}", enhanced_false_negatives);
        println!("  Reduction: {:.1}%", reduction_percentage);

        assert!(
            reduction_percentage >= 40.0,
            "False negative reduction ({:.1}%) does not meet spec 202 requirement (≥40%)",
            reduction_percentage
        );
        assert!(
            enhanced_detected as f64 / total as f64 >= 0.6,
            "Enhanced detector should catch at least 60% of non-standard state fields"
        );
    }

    /// Guards against gross per-detection slowdowns.
    #[test]
    fn test_performance_overhead() {
        use std::time::Instant;

        let detector = StateFieldDetector::new(StateDetectionConfig::default());
        let field: ExprField = parse_quote! { self.current_state };
        let iterations: u32 = 1000;

        let start = Instant::now();
        for _ in 0..iterations {
            let _ = detector.detect_state_field(&field);
        }
        let elapsed = start.elapsed();

        // Floating-point average: integer microsecond division truncates
        // sub-microsecond results to 0, and `{:.2}` has no effect on integers.
        let avg_time_us = elapsed.as_secs_f64() * 1_000_000.0 / f64::from(iterations);
        println!("Performance: avg {:.2}μs per detection", avg_time_us);
        assert!(
            avg_time_us < 5000.0,
            "Per-function overhead ({:.2}μs) exceeds spec 202 requirement (< 5000μs)",
            avg_time_us
        );
    }
}