use crate::pattern::Pattern;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, instrument};
use uuid::Uuid;
/// Classification-quality metrics derived from confusion-matrix counts.
///
/// Normally built with [`PatternMetrics::from_counts`], which fills the four
/// ratio fields from the four raw counts. Every ratio whose denominator is
/// zero is stored as `0.0`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PatternMetrics {
/// `tp / (tp + fp)`; `0.0` when there are no positive predictions.
pub precision: f32,
/// `tp / (tp + fn)`; `0.0` when there are no actual positives.
pub recall: f32,
/// Harmonic mean of `precision` and `recall`; `0.0` when both are zero.
pub f1_score: f32,
/// `(tp + tn) / (tp + fp + fn + tn)`; `0.0` when all counts are zero.
pub accuracy: f32,
/// Raw true-positive count (`tp`).
pub true_positives: usize,
/// Raw false-positive count (`fp`).
pub false_positives: usize,
/// Raw false-negative count (`fn`).
pub false_negatives: usize,
/// Raw true-negative count (`tn`).
pub true_negatives: usize,
}
impl PatternMetrics {
    /// Builds metrics from raw confusion-matrix counts.
    ///
    /// * `tp` – true positives
    /// * `fp` – false positives
    /// * `fn_` – false negatives (`fn` is a keyword, hence the underscore)
    /// * `tn` – true negatives
    ///
    /// Any ratio whose denominator is zero is reported as `0.0` instead of NaN.
    /// The counts are widened to `f64` before they are summed, so extreme inputs
    /// (e.g. `usize::MAX`) cannot overflow. Plain `usize` addition would panic in
    /// debug builds and silently wrap (producing wrong ratios) in release builds.
    #[must_use]
    pub fn from_counts(tp: usize, fp: usize, fn_: usize, tn: usize) -> Self {
        let (tp_f, fp_f, fn_f, tn_f) = (tp as f64, fp as f64, fn_ as f64, tn as f64);

        let precision = Self::safe_ratio(tp_f, tp_f + fp_f);
        let recall = Self::safe_ratio(tp_f, tp_f + fn_f);
        // F1 is the harmonic mean 2PR / (P + R). Written in terms of counts it is
        // 2tp / (2tp + fp + fn). That form is 0 whenever tp == 0, the same value
        // the "precision + recall == 0" convention gives, and it avoids
        // compounding the f32 rounding of P and R.
        let f1_score = Self::safe_ratio(2.0 * tp_f, 2.0 * tp_f + fp_f + fn_f);
        let accuracy = Self::safe_ratio(tp_f + tn_f, tp_f + fp_f + fn_f + tn_f);

        Self {
            precision,
            recall,
            f1_score,
            accuracy,
            true_positives: tp,
            false_positives: fp,
            false_negatives: fn_,
            true_negatives: tn,
        }
    }

    /// Returns `numerator / denominator` narrowed to `f32`, or `0.0` when the
    /// denominator is not positive (for count sums that means it is zero).
    fn safe_ratio(numerator: f64, denominator: f64) -> f32 {
        if denominator > 0.0 {
            (numerator / denominator) as f32
        } else {
            0.0
        }
    }

    /// Returns `true` when `precision >= target_precision` and
    /// `recall >= target_recall` both hold (both bounds inclusive).
    #[must_use]
    pub fn meets_target(&self, target_precision: f32, target_recall: f32) -> bool {
        self.precision >= target_precision && self.recall >= target_recall
    }

    /// Single scalar summary: `0.6 * f1 + 0.25 * precision + 0.15 * recall`.
    ///
    /// The weights sum to 1.0. The result therefore stays in `[0, 1]` whenever
    /// the three inputs do, as they do for values produced by [`Self::from_counts`].
    #[must_use]
    pub fn quality_score(&self) -> f32 {
        (self.f1_score * 0.6) + (self.precision * 0.25) + (self.recall * 0.15)
    }
}
/// Configuration held by [`PatternValidator`].
///
/// Default values come from the `Default` impl: `min_confidence` 0.7,
/// `similarity_threshold` 0.8, `max_false_positive_rate` 0.2, `min_recall` 0.7.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
/// Minimum `Pattern::success_rate()` (inclusive) that
/// `PatternValidator::validate_confidence` accepts.
pub min_confidence: f32,
/// NOTE(review): not read by any code in this module; presumably a
/// similarity cut-off used by callers. Confirm where it is consumed.
pub similarity_threshold: f32,
/// NOTE(review): not read by any code in this module; presumably an upper
/// bound on the acceptable false-positive rate. Confirm with callers.
pub max_false_positive_rate: f32,
/// NOTE(review): not read by any code in this module; presumably compared
/// against `PatternMetrics::recall` by callers. Confirm.
pub min_recall: f32,
}
impl Default for ValidationConfig {
    /// Returns the baseline thresholds:
    /// - `min_confidence` = 0.7
    /// - `similarity_threshold` = 0.8
    /// - `max_false_positive_rate` = 0.2
    /// - `min_recall` = 0.7
    fn default() -> Self {
        let min_confidence = 0.7;
        let similarity_threshold = 0.8;
        let max_false_positive_rate = 0.2;
        let min_recall = 0.7;

        Self {
            min_confidence,
            similarity_threshold,
            max_false_positive_rate,
            min_recall,
        }
    }
}
/// Validates patterns against a [`ValidationConfig`] and keeps a per-pattern
/// confidence estimate fed by `track_effectiveness`.
pub struct PatternValidator {
/// Thresholds used when validating patterns.
config: ValidationConfig,
/// Confidence per pattern id, updated as an exponential moving average
/// (`0.9 * previous + 0.1 * outcome`, starting from 0.5) by
/// `track_effectiveness`. NOTE(review): nothing in this module evicts
/// entries, so the map grows with every distinct id that is tracked.
confidence_cache: HashMap<Uuid, f32>,
}
impl PatternValidator {
    /// Creates a validator with the given configuration and no tracked confidences.
    #[must_use]
    pub fn new(config: ValidationConfig) -> Self {
        let confidence_cache = HashMap::default();
        Self {
            config,
            confidence_cache,
        }
    }

    /// Returns `true` when the pattern's own `success_rate()` is at least
    /// `config.min_confidence` (inclusive).
    ///
    /// This does not consult the confidence values recorded by
    /// [`Self::track_effectiveness`].
    #[must_use]
    pub fn validate_confidence(&self, pattern: &Pattern) -> bool {
        let threshold = self.config.min_confidence;
        pattern.success_rate() >= threshold
    }

    /// Records one observation of a pattern and updates its tracked confidence.
    ///
    /// Observations with `used == false` are ignored entirely. Otherwise:
    /// - The stored confidence starts at 0.5 for an unseen id.
    /// - It is updated as `0.9 * previous + 0.1 * outcome`, where `outcome` is
    ///   1.0 on success and 0.0 on failure.
    /// - The new value is emitted as a `debug!` event.
    pub fn track_effectiveness(&mut self, pattern_id: Uuid, used: bool, successful: bool) {
        // Starting estimate for an id with no prior observations.
        const PRIOR: f32 = 0.5;
        // Share of the previous estimate that is kept on each update.
        const RETAIN: f32 = 0.9;
        // Share given to the newest outcome.
        const LEARN: f32 = 0.1;

        if !used {
            return;
        }

        let outcome: f32 = if successful { 1.0 } else { 0.0 };
        let slot = self.confidence_cache.entry(pattern_id).or_insert(PRIOR);
        *slot = (*slot * RETAIN) + (outcome * LEARN);
        let new_confidence = *slot;

        debug!(
            pattern_id = %pattern_id,
            used = used,
            successful = successful,
            new_confidence = new_confidence,
            "Tracked pattern effectiveness"
        );
    }

    /// Returns the tracked confidence for `pattern_id`. Returns `None` if the id
    /// has never been passed to [`Self::track_effectiveness`] with `used == true`.
    #[must_use]
    pub fn get_confidence(&self, pattern_id: Uuid) -> Option<f32> {
        let cached = self.confidence_cache.get(&pattern_id);
        cached.copied()
    }

    /// Borrows the configuration this validator was created with.
    #[must_use]
    pub fn config(&self) -> &ValidationConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ComplexityLevel, TaskContext};
    use chrono::Duration;
    use uuid::Uuid;

    /// Minimal task context shared by the pattern fixtures below.
    fn create_test_context() -> TaskContext {
        TaskContext {
            language: Some("rust".to_string()),
            framework: Some("tokio".to_string()),
            complexity: ComplexityLevel::Moderate,
            domain: "testing".to_string(),
            tags: vec!["async".to_string()],
        }
    }

    #[test]
    fn test_pattern_metrics_calculation() {
        let metrics = PatternMetrics::from_counts(5, 0, 0, 5);
        assert_eq!(metrics.precision, 1.0);
        assert_eq!(metrics.recall, 1.0);
        assert_eq!(metrics.f1_score, 1.0);
        assert_eq!(metrics.accuracy, 1.0);

        let metrics = PatternMetrics::from_counts(3, 2, 1, 4);
        assert_eq!(metrics.precision, 0.6);
        assert_eq!(metrics.recall, 0.75);
        assert_eq!(metrics.accuracy, 0.7);
        // F1 = 2 * 0.6 * 0.75 / (0.6 + 0.75) = 2/3.
        assert!((metrics.f1_score - 2.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_pattern_metrics_edge_cases() {
        let metrics = PatternMetrics::from_counts(0, 0, 5, 5);
        assert_eq!(metrics.precision, 0.0);
        assert_eq!(metrics.recall, 0.0);
        assert_eq!(metrics.f1_score, 0.0);

        let metrics = PatternMetrics::from_counts(0, 5, 0, 5);
        assert_eq!(metrics.precision, 0.0);
        assert_eq!(metrics.recall, 0.0);

        // No observations at all: every ratio falls back to 0.0 rather than NaN.
        let metrics = PatternMetrics::from_counts(0, 0, 0, 0);
        assert_eq!(metrics.precision, 0.0);
        assert_eq!(metrics.recall, 0.0);
        assert_eq!(metrics.f1_score, 0.0);
        assert_eq!(metrics.accuracy, 0.0);
    }

    #[test]
    fn test_meets_target() {
        let metrics = PatternMetrics::from_counts(3, 2, 1, 4);
        // Both bounds are inclusive.
        assert!(metrics.meets_target(0.6, 0.75));
        assert!(!metrics.meets_target(0.7, 0.75));
        assert!(!metrics.meets_target(0.6, 0.8));
    }

    #[test]
    fn test_default_config() {
        let config = ValidationConfig::default();
        assert_eq!(config.min_confidence, 0.7);
        assert_eq!(config.similarity_threshold, 0.8);
        assert_eq!(config.max_false_positive_rate, 0.2);
        assert_eq!(config.min_recall, 0.7);

        let validator = PatternValidator::new(config);
        assert_eq!(validator.config().min_confidence, 0.7);
    }

    #[test]
    fn test_validate_confidence() {
        let config = ValidationConfig {
            min_confidence: 0.7,
            ..Default::default()
        };
        let validator = PatternValidator::new(config);

        let high_conf_pattern = Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: vec!["tool1".to_string()],
            context: create_test_context(),
            success_rate: 0.9,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 10,
            effectiveness: crate::pattern::PatternEffectiveness::new(),
        };
        let low_conf_pattern = Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: vec!["tool2".to_string()],
            context: create_test_context(),
            success_rate: 0.5,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 3,
            effectiveness: crate::pattern::PatternEffectiveness::new(),
        };

        assert!(validator.validate_confidence(&high_conf_pattern));
        assert!(!validator.validate_confidence(&low_conf_pattern));
    }

    #[test]
    fn test_track_effectiveness() {
        let config = ValidationConfig::default();
        let mut validator = PatternValidator::new(config);
        let pattern_id = Uuid::new_v4();

        validator.track_effectiveness(pattern_id, true, true);
        let conf1 = validator.get_confidence(pattern_id).unwrap();
        assert!(conf1 > 0.5);
        // First update from the 0.5 prior: 0.5 * 0.9 + 0.1 = 0.55.
        assert!((conf1 - 0.55).abs() < 1e-6);

        validator.track_effectiveness(pattern_id, true, true);
        let conf2 = validator.get_confidence(pattern_id).unwrap();
        assert!(conf2 > conf1);

        validator.track_effectiveness(pattern_id, true, false);
        let conf3 = validator.get_confidence(pattern_id).unwrap();
        assert!(conf3 < conf2);
    }

    #[test]
    fn test_track_effectiveness_ignores_unused_and_unknown() {
        let mut validator = PatternValidator::new(ValidationConfig::default());
        let pattern_id = Uuid::new_v4();

        // Unknown ids have no confidence yet.
        assert_eq!(validator.get_confidence(pattern_id), None);

        // An observation with `used == false` must not create an entry.
        validator.track_effectiveness(pattern_id, false, true);
        assert_eq!(validator.get_confidence(pattern_id), None);
    }

    #[test]
    fn test_quality_score() {
        let metrics = PatternMetrics::from_counts(8, 2, 1, 9);
        let score = metrics.quality_score();
        assert!(score > 0.0 && score <= 1.0);

        let perfect = PatternMetrics::from_counts(10, 0, 0, 10);
        assert!(perfect.quality_score() > 0.95);
    }
}