use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use super::assessment::IntegrityAssessment;
use super::counts::ViolationCounts;
use super::violation::{MetamorphicRelationType, MetamorphicViolation};
/// Behavioral-integrity metrics for a single model, together with the
/// metamorphic violations recorded against it.
///
/// All fields are public and the type derives `Deserialize`, so the
/// `[0.0, 1.0]` clamping applied by `BehavioralIntegrity::new` is NOT enforced
/// for values assigned directly or produced by deserialization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BehavioralIntegrity {
/// Equivalence score; clamped to `[0.0, 1.0]` by `new`. Carries a weight of
/// 0.3 in `composite_score`.
pub equivalence_score: f64,
/// Syscall match score; clamped to `[0.0, 1.0]` by `new`. Carries a weight of
/// 0.2 in `composite_score`.
pub syscall_match: f64,
/// Timing variance; clamped to `[0.0, 1.0]` by `new`. Lower is better:
/// `composite_score` uses `1.0 - timing_variance` (weight 0.2) and
/// `passes_gate` requires it to be strictly below 0.2.
pub timing_variance: f64,
/// Semantic equivalence score; clamped to `[0.0, 1.0]` by `new`. Carries a
/// weight of 0.3 in `composite_score`.
pub semantic_equiv: f64,
/// Violations recorded so far; empty after `new`, appended to by
/// `add_violation`.
pub violations: Vec<MetamorphicViolation>,
/// UTC time captured by `new` when the value was constructed.
pub timestamp: DateTime<Utc>,
/// Number of tests run; 0 after `new`, set through `with_test_count` and
/// reported as "Tests Run" by `summary`.
pub test_count: u32,
/// Identifier of the model these metrics belong to; reported as "Model" by
/// `summary`.
pub model_id: String,
}
impl BehavioralIntegrity {
    /// Composite score at or above which [`summary`](Self::summary) reports
    /// the gate as passed.
    const SUMMARY_GATE_THRESHOLD: f64 = 0.9;

    /// Timing variance at or above which [`passes_gate`](Self::passes_gate)
    /// fails regardless of the composite score.
    const MAX_GATE_TIMING_VARIANCE: f64 = 0.2;

    /// Clamps `value` into `[0.0, 1.0]`, substituting `nan_fallback` for NaN.
    ///
    /// `f64::clamp` returns NaN for a NaN input, so NaN has to be handled
    /// explicitly to keep the stored metrics inside the unit interval.
    fn unit_interval(value: f64, nan_fallback: f64) -> f64 {
        if value.is_nan() {
            nan_fallback
        } else {
            value.clamp(0.0, 1.0)
        }
    }

    /// Creates a new record with no violations, a test count of zero and the
    /// current UTC time as its timestamp.
    ///
    /// Every metric is clamped to `[0.0, 1.0]`. A NaN metric is replaced by
    /// the worst value for that metric (`0.0` for the three scores, `1.0` for
    /// `timing_variance`) so that an unusable measurement can never yield a
    /// NaN composite score.
    #[must_use]
    pub fn new(
        equivalence_score: f64,
        syscall_match: f64,
        timing_variance: f64,
        semantic_equiv: f64,
        model_id: impl Into<String>,
    ) -> Self {
        Self {
            equivalence_score: Self::unit_interval(equivalence_score, 0.0),
            syscall_match: Self::unit_interval(syscall_match, 0.0),
            // Variance is "lower is better", so its worst case is 1.0.
            timing_variance: Self::unit_interval(timing_variance, 1.0),
            semantic_equiv: Self::unit_interval(semantic_equiv, 0.0),
            violations: Vec::new(),
            timestamp: Utc::now(),
            test_count: 0,
            model_id: model_id.into(),
        }
    }

    /// Creates a record with the best possible metrics: all scores `1.0` and
    /// a timing variance of `0.0`.
    #[must_use]
    pub fn perfect(model_id: impl Into<String>) -> Self {
        Self::new(1.0, 1.0, 0.0, 1.0, model_id)
    }

    /// Appends `violation` to the recorded violations.
    pub fn add_violation(&mut self, violation: MetamorphicViolation) {
        self.violations.push(violation);
    }

    /// Builder-style setter: returns `self` with `test_count` set to `count`.
    #[must_use]
    pub fn with_test_count(mut self, count: u32) -> Self {
        self.test_count = count;
        self
    }

    /// Returns the weighted combination of the four metrics.
    ///
    /// Weights are 0.3 (equivalence), 0.2 (syscall match), 0.2 (timing, taken
    /// as `1.0 - timing_variance`) and 0.3 (semantic equivalence). The result
    /// lies in `[0.0, 1.0]` as long as every field does; the fields are
    /// public, so that is only guaranteed for values built through
    /// [`new`](Self::new) and not modified afterwards.
    pub fn composite_score(&self) -> f64 {
        const W_EQUIV: f64 = 0.3;
        const W_SYSCALL: f64 = 0.2;
        const W_TIMING: f64 = 0.2;
        const W_SEMANTIC: f64 = 0.3;

        // Variance is "lower is better"; invert it so it adds like a score.
        let timing_score = 1.0 - self.timing_variance;

        W_EQUIV * self.equivalence_score
            + W_SYSCALL * self.syscall_match
            + W_TIMING * timing_score
            + W_SEMANTIC * self.semantic_equiv
    }

    /// Returns `true` when the composite score is at least `threshold`, the
    /// timing variance is strictly below 0.2 and no critical violation has
    /// been recorded.
    ///
    /// A NaN score or threshold never passes, because `>=` is false for NaN.
    pub fn passes_gate(&self, threshold: f64) -> bool {
        self.composite_score() >= threshold
            && self.timing_variance < Self::MAX_GATE_TIMING_VARIANCE
            && !self.has_critical_violations()
    }

    /// Returns `true` if at least one recorded violation is critical.
    pub fn has_critical_violations(&self) -> bool {
        self.violations.iter().any(MetamorphicViolation::is_critical)
    }

    /// Counts the recorded violations per severity bucket in a single pass.
    ///
    /// Each violation lands in exactly one bucket, checked in the order
    /// critical, warning, minor, so `critical + warnings + minor` always
    /// equals the number of violations. Counts saturate at `u32::MAX` instead
    /// of wrapping.
    pub fn violation_counts(&self) -> ViolationCounts {
        let mut critical = 0u32;
        let mut warnings = 0u32;
        let mut minor = 0u32;

        for violation in &self.violations {
            let bucket = if violation.is_critical() {
                &mut critical
            } else if violation.is_warning() {
                &mut warnings
            } else {
                &mut minor
            };
            *bucket = bucket.saturating_add(1);
        }

        let total = u32::try_from(self.violations.len()).unwrap_or(u32::MAX);
        ViolationCounts { critical, warnings, minor, total }
    }

    /// Groups references to the recorded violations by their relation type.
    ///
    /// Within each group the violations keep their recording order.
    pub fn violations_by_type(
        &self,
    ) -> std::collections::HashMap<MetamorphicRelationType, Vec<&MetamorphicViolation>> {
        let mut grouped: std::collections::HashMap<_, Vec<_>> = std::collections::HashMap::new();
        for violation in &self.violations {
            grouped.entry(violation.relation_type).or_default().push(violation);
        }
        grouped
    }

    /// Returns the violation with the highest severity, or `None` when no
    /// violation has been recorded.
    ///
    /// Severities that cannot be ordered (`partial_cmp` yields `None`) are
    /// treated as equal; among equal maxima `Iterator::max_by` returns the
    /// last one.
    pub fn most_severe_violation(&self) -> Option<&MetamorphicViolation> {
        self.violations
            .iter()
            .max_by(|a, b| a.severity.partial_cmp(&b.severity).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Classifies the record into an [`IntegrityAssessment`] band.
    ///
    /// Any critical violation yields `Critical`. Otherwise the composite
    /// score selects the band: `>= 0.9` Excellent, `>= 0.7` Good, `>= 0.5`
    /// Fair, anything else Poor.
    pub fn assessment(&self) -> IntegrityAssessment {
        if self.has_critical_violations() {
            return IntegrityAssessment::Critical;
        }

        let score = self.composite_score();
        // Bands are tested from the top down with `>=` so that a NaN score,
        // for which every comparison is false, ends up as Poor instead of
        // falling through to Excellent.
        if score >= 0.9 {
            IntegrityAssessment::Excellent
        } else if score >= 0.7 {
            IntegrityAssessment::Good
        } else if score >= 0.5 {
            IntegrityAssessment::Fair
        } else {
            IntegrityAssessment::Poor
        }
    }

    /// Renders a multi-line, human-readable report: model id, composite score
    /// as a percentage, assessment, violation counts, test count and the gate
    /// status evaluated at a threshold of 0.9.
    pub fn summary(&self) -> String {
        let counts = self.violation_counts();
        let gate_status =
            if self.passes_gate(Self::SUMMARY_GATE_THRESHOLD) { "PASS" } else { "FAIL" };

        format!(
            "Model: {}\n\
             Composite Score: {:.1}%\n\
             Assessment: {}\n\
             Violations: {} critical, {} warnings, {} minor\n\
             Tests Run: {}\n\
             Gate Status: {}",
            self.model_id,
            self.composite_score() * 100.0,
            self.assessment(),
            counts.critical,
            counts.warnings,
            counts.minor,
            self.test_count,
            gate_status
        )
    }
}