use super::rules::{PatternMetrics, PatternValidator, ValidationConfig};
use crate::pattern::Pattern;
use std::collections::HashMap;
use tracing::{debug, instrument};
use uuid::Uuid;
impl PatternValidator {
#[instrument(skip(self, ground_truth, extracted))]
pub fn calculate_metrics(
&self,
ground_truth: &[Pattern],
extracted: &[Pattern],
) -> PatternMetrics {
let mut tp = 0; let mut fp = 0;
let gt_map = self.build_pattern_map(ground_truth);
let mut matched_gt: std::collections::HashSet<Uuid> = std::collections::HashSet::new();
for extracted_pattern in extracted {
if let Some(gt_pattern) = self.find_matching_pattern(extracted_pattern, >_map) {
if self.patterns_match(extracted_pattern, gt_pattern) {
tp += 1;
matched_gt.insert(gt_pattern.id());
} else {
fp += 1;
}
} else {
fp += 1; }
}
let fn_ = ground_truth.len() - matched_gt.len();
let tn = 0;
debug!(
tp = tp,
fp = fp,
fn_ = fn_,
tn = tn,
"Calculated pattern validation metrics"
);
PatternMetrics::from_counts(tp, fp, fn_, tn)
}
fn build_pattern_map<'a>(&self, patterns: &'a [Pattern]) -> HashMap<String, Vec<&'a Pattern>> {
let mut map: HashMap<String, Vec<&'a Pattern>> = HashMap::new();
for pattern in patterns {
let key = self.pattern_type_key(pattern);
map.entry(key).or_default().push(pattern);
}
map
}
fn pattern_type_key(&self, pattern: &Pattern) -> String {
match pattern {
Pattern::ToolSequence { tools, .. } => format!("tool_seq_{}", tools.join("_")),
Pattern::DecisionPoint { condition, .. } => format!("decision_{condition}"),
Pattern::ErrorRecovery { error_type, .. } => format!("error_{error_type}"),
Pattern::ContextPattern {
context_features, ..
} => format!("context_{}", context_features.join("_")),
}
}
fn find_matching_pattern<'a>(
&self,
extracted: &Pattern,
gt_map: &'a HashMap<String, Vec<&'a Pattern>>,
) -> Option<&'a Pattern> {
let key = self.pattern_type_key(extracted);
if let Some(candidates) = gt_map.get(&key) {
for candidate in candidates {
if self.patterns_match(extracted, candidate) {
return Some(candidate);
}
}
}
None
}
fn patterns_match(&self, p1: &Pattern, p2: &Pattern) -> bool {
match (p1, p2) {
(Pattern::ToolSequence { tools: t1, .. }, Pattern::ToolSequence { tools: t2, .. }) => {
self.sequence_similarity(t1, t2) >= self.config.similarity_threshold
}
(
Pattern::DecisionPoint {
condition: c1,
action: a1,
..
},
Pattern::DecisionPoint {
condition: c2,
action: a2,
..
},
) => {
self.string_similarity(c1, c2) >= self.config.similarity_threshold
&& self.string_similarity(a1, a2) >= self.config.similarity_threshold
}
(
Pattern::ErrorRecovery {
error_type: e1,
recovery_steps: r1,
..
},
Pattern::ErrorRecovery {
error_type: e2,
recovery_steps: r2,
..
},
) => {
self.string_similarity(e1, e2) >= self.config.similarity_threshold
&& self.sequence_similarity(r1, r2) >= self.config.similarity_threshold
}
(
Pattern::ContextPattern {
context_features: f1,
..
},
Pattern::ContextPattern {
context_features: f2,
..
},
) => {
self.sequence_similarity(f1, f2) >= self.config.similarity_threshold
}
_ => false, }
}
fn sequence_similarity(&self, seq1: &[String], seq2: &[String]) -> f32 {
if seq1.is_empty() && seq2.is_empty() {
return 1.0;
}
if seq1.is_empty() || seq2.is_empty() {
return 0.0;
}
let seq1_set: std::collections::HashSet<_> = seq1.iter().collect();
let seq2_set: std::collections::HashSet<_> = seq2.iter().collect();
let intersection = seq1_set.intersection(&seq2_set).count();
let union = seq1_set.union(&seq2_set).count();
if union > 0 {
intersection as f32 / union as f32
} else {
0.0
}
}
fn string_similarity(&self, s1: &str, s2: &str) -> f32 {
if s1 == s2 {
return 1.0;
}
let s1_lower = s1.to_lowercase();
let s2_lower = s2.to_lowercase();
let words1: std::collections::HashSet<_> = s1_lower.split_whitespace().collect();
let words2: std::collections::HashSet<_> = s2_lower.split_whitespace().collect();
let intersection = words1.intersection(&words2).count();
let union = words1.union(&words2).count();
if union > 0 {
intersection as f32 / union as f32
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ComplexityLevel, TaskContext};
    use chrono::Duration;

    /// Task context shared by every fixture pattern in this module.
    fn create_test_context() -> TaskContext {
        TaskContext {
            language: Some("rust".to_string()),
            framework: Some("tokio".to_string()),
            complexity: ComplexityLevel::Moderate,
            domain: "testing".to_string(),
            tags: vec!["async".to_string()],
        }
    }

    /// Validator built from the default validation configuration.
    fn default_validator() -> PatternValidator {
        PatternValidator::new(ValidationConfig::default())
    }

    /// Converts string literals into the owned strings the patterns store.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Tool-sequence fixture with a fresh id and fixed statistics.
    fn tool_sequence(tools: &[&str]) -> Pattern {
        Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(tools),
            context: create_test_context(),
            success_rate: 0.9,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 5,
            effectiveness: crate::pattern::PatternEffectiveness::new(),
        }
    }

    /// Error-recovery fixture with a fresh id and fixed statistics.
    fn error_recovery(error_type: &str, recovery_steps: &[&str]) -> Pattern {
        Pattern::ErrorRecovery {
            id: Uuid::new_v4(),
            error_type: error_type.to_string(),
            recovery_steps: owned(recovery_steps),
            success_rate: 0.8,
            context: create_test_context(),
            effectiveness: crate::pattern::PatternEffectiveness::new(),
        }
    }

    /// Ground truth used by the metric tests: one tool sequence and one
    /// error-recovery pattern.
    fn sample_ground_truth() -> Vec<Pattern> {
        vec![
            tool_sequence(&["tool1", "tool2"]),
            error_recovery("timeout", &["retry"]),
        ]
    }

    #[test]
    fn test_calculate_metrics_perfect_match() {
        let validator = default_validator();
        let ground_truth = sample_ground_truth();
        // Extracting exactly the ground truth must yield perfect scores.
        let extracted = ground_truth.clone();

        let metrics = validator.calculate_metrics(&ground_truth, &extracted);

        assert_eq!(metrics.true_positives, 2);
        assert_eq!(metrics.false_positives, 0);
        assert_eq!(metrics.false_negatives, 0);
        assert_eq!(metrics.precision, 1.0);
        assert_eq!(metrics.recall, 1.0);
        assert_eq!(metrics.f1_score, 1.0);
    }

    #[test]
    fn test_calculate_metrics_partial_match() {
        let validator = default_validator();
        let ground_truth = sample_ground_truth();
        // Only the tool sequence is recovered (under a different id); the
        // error-recovery pattern is missed.
        let extracted = vec![tool_sequence(&["tool1", "tool2"])];

        let metrics = validator.calculate_metrics(&ground_truth, &extracted);

        assert_eq!(metrics.true_positives, 1);
        assert_eq!(metrics.false_positives, 0);
        assert_eq!(metrics.false_negatives, 1);
        assert_eq!(metrics.precision, 1.0);
        assert_eq!(metrics.recall, 0.5);
        assert!((metrics.f1_score - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_sequence_similarity() {
        let validator = default_validator();
        let abc = owned(&["a", "b", "c"]);

        // Identical sequences.
        let abc_again = owned(&["a", "b", "c"]);
        assert_eq!(validator.sequence_similarity(&abc, &abc_again), 1.0);

        // Partial overlap.
        let ab = owned(&["a", "b"]);
        let partial = validator.sequence_similarity(&abc, &ab);
        assert!(partial > 0.5 && partial < 1.0);

        // No overlap.
        let xy = owned(&["x", "y"]);
        let disjoint = validator.sequence_similarity(&abc, &xy);
        assert!(disjoint < 0.5);
    }

    #[test]
    fn test_string_similarity() {
        let validator = default_validator();

        // Identical strings.
        assert_eq!(
            validator.string_similarity("hello world", "hello world"),
            1.0
        );

        // One shared word out of three distinct words.
        let partial = validator.string_similarity("hello world", "hello there");
        assert!(partial > 0.3 && partial < 1.0);

        // No shared words.
        let disjoint = validator.string_similarity("hello", "goodbye");
        assert_eq!(disjoint, 0.0);
    }
}