use super::types::ConfidenceLevel;
use super::types::{Mismatch, MismatchSeverity, Recommendation};
use std::collections::HashMap;
pub struct ConfidenceScorer;
impl ConfidenceScorer {
pub fn score_mismatch(mismatch: &Mismatch, context: &ScoringContext) -> f64 {
let mut score = 0.5;
let severity_factor = match mismatch.severity {
MismatchSeverity::Critical => 0.9,
MismatchSeverity::High => 0.8,
MismatchSeverity::Medium => 0.7,
MismatchSeverity::Low => 0.6,
MismatchSeverity::Info => 0.5,
};
score = (score + severity_factor) / 2.0;
let type_factor = match mismatch.mismatch_type {
super::types::MismatchType::MissingRequiredField => 0.95,
super::types::MismatchType::TypeMismatch => 0.9,
super::types::MismatchType::SchemaMismatch => 0.85,
super::types::MismatchType::FormatMismatch => 0.8,
super::types::MismatchType::ConstraintViolation => 0.75,
super::types::MismatchType::UnexpectedField => 0.7,
super::types::MismatchType::EndpointNotFound => 0.9,
super::types::MismatchType::MethodNotAllowed => 0.9,
super::types::MismatchType::HeaderMismatch => 0.7,
super::types::MismatchType::QueryParamMismatch => 0.65,
super::types::MismatchType::SemanticDescriptionChange => 0.6,
super::types::MismatchType::SemanticEnumNarrowing => 0.65,
super::types::MismatchType::SemanticNullabilityChange => 0.7,
super::types::MismatchType::SemanticErrorCodeRemoved => 0.75,
};
score = (score + type_factor) / 2.0;
if context.occurrence_count > 1 {
let consistency_boost = (context.occurrence_count as f64).min(10.0) / 10.0 * 0.2;
score = (score + consistency_boost).min(1.0);
}
if context.schema_clarity > 0.7 {
score = (score + 0.1).min(1.0);
} else if context.schema_clarity < 0.3 {
score = (score - 0.1).max(0.0);
}
score.clamp(0.0, 1.0)
}
pub fn score_recommendation(recommendation: &Recommendation, mismatch_confidence: f64) -> f64 {
let mut score = mismatch_confidence;
if recommendation.suggested_fix.is_some() {
score = (score + 0.1).min(1.0);
}
if recommendation.reasoning.is_some() {
score = (score + 0.05).min(1.0);
}
if recommendation.example.is_some() {
score = (score + 0.05).min(1.0);
}
score.min(recommendation.confidence)
}
pub fn score_correction(
correction: &super::types::CorrectionProposal,
mismatch_confidence: f64,
) -> f64 {
let mut score = mismatch_confidence;
if correction.before.is_some() && correction.after.is_some() {
score = (score + 0.1).min(1.0);
}
if correction.reasoning.is_some() {
score = (score + 0.05).min(1.0);
}
if correction.affected_endpoints.len() > 5 {
score = (score - 0.05).max(0.0);
}
score.min(correction.confidence)
}
pub fn get_confidence_level(score: f64) -> ConfidenceLevel {
ConfidenceLevel::from_score(score)
}
pub fn calculate_overall_confidence(mismatches: &[Mismatch]) -> f64 {
if mismatches.is_empty() {
return 1.0; }
let mut total_weighted_score = 0.0;
let mut total_weight = 0.0;
for mismatch in mismatches {
let weight = match mismatch.severity {
MismatchSeverity::Critical => 5.0,
MismatchSeverity::High => 4.0,
MismatchSeverity::Medium => 3.0,
MismatchSeverity::Low => 2.0,
MismatchSeverity::Info => 1.0,
};
total_weighted_score += mismatch.confidence * weight;
total_weight += weight;
}
if total_weight > 0.0 {
total_weighted_score / total_weight
} else {
0.0
}
}
}
/// Extra signals supplied alongside a mismatch when it is scored.
#[derive(Debug, Clone)]
pub struct ScoringContext {
    /// Number of times the mismatch was observed; values above 1 earn a
    /// consistency boost in `ConfidenceScorer::score_mismatch`.
    pub occurrence_count: usize,
    /// Clarity of the schema; `ConfidenceScorer::score_mismatch` rewards
    /// values above 0.7 and penalises values below 0.3.
    pub schema_clarity: f64,
    /// Named auxiliary factors attached through [`ScoringContext::with_factor`].
    /// NOTE(review): not read by `ConfidenceScorer::score_mismatch` in this file.
    pub factors: HashMap<String, f64>,
}

impl Default for ScoringContext {
    /// A single occurrence with neutral (0.5) schema clarity and no extra factors.
    fn default() -> Self {
        Self::new(1, 0.5)
    }
}

impl ScoringContext {
    /// Creates a context with the given occurrence count and schema clarity
    /// and an empty factor map.
    pub fn new(occurrence_count: usize, schema_clarity: f64) -> Self {
        Self {
            factors: HashMap::new(),
            occurrence_count,
            schema_clarity,
        }
    }

    /// Builder-style setter that records `value` under `key`, replacing any
    /// earlier value for the same key, and returns the updated context.
    pub fn with_factor(self, key: impl Into<String>, value: f64) -> Self {
        let Self {
            occurrence_count,
            schema_clarity,
            mut factors,
        } = self;
        factors.insert(key.into(), value);
        Self {
            occurrence_count,
            schema_clarity,
            factors,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai_contract_diff::types::{Mismatch, MismatchSeverity, MismatchType};

    /// Tolerance for comparing scores produced by floating-point arithmetic.
    const EPSILON: f64 = 1e-9;

    /// Builds a mismatch from its scoring-relevant fields; every other field is a placeholder.
    fn mismatch(
        mismatch_type: MismatchType,
        severity: MismatchSeverity,
        confidence: f64,
    ) -> Mismatch {
        Mismatch {
            mismatch_type,
            path: "/user/field".to_string(),
            method: None,
            expected: None,
            actual: None,
            description: "test mismatch".to_string(),
            severity,
            confidence,
            context: HashMap::new(),
        }
    }

    /// Asserts that `actual` equals `expected` within [`EPSILON`].
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_confidence_level_from_score() {
        assert_eq!(ConfidenceLevel::from_score(0.9), ConfidenceLevel::High);
        assert_eq!(ConfidenceLevel::from_score(0.7), ConfidenceLevel::Medium);
        assert_eq!(ConfidenceLevel::from_score(0.3), ConfidenceLevel::Low);
        assert_eq!(
            ConfidenceScorer::get_confidence_level(0.9),
            ConfidenceLevel::High
        );
    }

    #[test]
    fn test_score_mismatch() {
        let m = mismatch(
            MismatchType::MissingRequiredField,
            MismatchSeverity::Critical,
            0.0,
        );
        let context = ScoringContext::new(1, 0.8);
        // (0.5 + 0.9) / 2 = 0.7; (0.7 + 0.95) / 2 = 0.825; a clear schema adds 0.1.
        assert_close(ConfidenceScorer::score_mismatch(&m, &context), 0.925);
    }

    #[test]
    fn test_score_mismatch_low_clarity_penalty() {
        let m = mismatch(
            MismatchType::SemanticDescriptionChange,
            MismatchSeverity::Info,
            0.0,
        );
        // (0.5 + 0.5) / 2 = 0.5; (0.5 + 0.6) / 2 = 0.55; an unclear schema subtracts 0.1.
        assert_close(
            ConfidenceScorer::score_mismatch(&m, &ScoringContext::new(1, 0.1)),
            0.45,
        );
    }

    #[test]
    fn test_score_mismatch_consistency_boost() {
        let m = mismatch(
            MismatchType::MissingRequiredField,
            MismatchSeverity::Critical,
            0.0,
        );
        // Two occurrences add 2 / 10 * 0.2 = 0.04; neutral clarity changes nothing.
        assert_close(
            ConfidenceScorer::score_mismatch(&m, &ScoringContext::new(2, 0.5)),
            0.865,
        );
        // The boost saturates at ten occurrences and the score is capped at 1.0 ...
        assert_close(
            ConfidenceScorer::score_mismatch(&m, &ScoringContext::new(100, 0.8)),
            1.0,
        );
        // ... before the low-clarity penalty is subtracted.
        assert_close(
            ConfidenceScorer::score_mismatch(&m, &ScoringContext::new(100, 0.1)),
            0.9,
        );
    }

    #[test]
    fn test_calculate_overall_confidence() {
        let mismatches = [
            mismatch(MismatchType::TypeMismatch, MismatchSeverity::High, 0.9),
            mismatch(MismatchType::UnexpectedField, MismatchSeverity::Low, 0.6),
        ];
        // (0.9 * 4 + 0.6 * 2) / (4 + 2) = 0.8
        assert_close(
            ConfidenceScorer::calculate_overall_confidence(&mismatches),
            0.8,
        );
    }

    #[test]
    fn test_empty_mismatches_confidence() {
        let mismatches: Vec<Mismatch> = Vec::new();
        assert_eq!(
            ConfidenceScorer::calculate_overall_confidence(&mismatches),
            1.0
        );
    }

    #[test]
    fn test_scoring_context_default() {
        let context = ScoringContext::default();
        assert_eq!(context.occurrence_count, 1);
        assert_eq!(context.schema_clarity, 0.5);
        assert!(context.factors.is_empty());
    }
}