Skip to main content

datasynth_eval/gates/
engine.rs

1//! Quality gate evaluation engine.
2//!
3//! Evaluates generation results against configurable pass/fail criteria.
4
5use serde::{Deserialize, Serialize};
6
7use crate::ComprehensiveEvaluation;
8
9/// A quality metric that can be checked by a gate.
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(rename_all = "snake_case")]
12pub enum QualityMetric {
13    /// Benford's Law Mean Absolute Deviation.
14    BenfordMad,
15    /// Balance sheet coherence rate (0.0–1.0).
16    BalanceCoherence,
17    /// Document chain integrity rate (0.0–1.0).
18    DocumentChainIntegrity,
19    /// Correlation preservation score (0.0–1.0).
20    CorrelationPreservation,
21    /// Temporal consistency score (0.0–1.0).
22    TemporalConsistency,
23    /// Privacy MIA AUC-ROC score.
24    PrivacyMiaAuc,
25    /// Data completion rate (0.0–1.0).
26    CompletionRate,
27    /// Duplicate rate (0.0–1.0).
28    DuplicateRate,
29    /// Referential integrity rate (0.0–1.0).
30    ReferentialIntegrity,
31    /// Intercompany match rate (0.0–1.0).
32    IcMatchRate,
33    /// Custom metric identified by name.
34    Custom(String),
35}
36
37impl std::fmt::Display for QualityMetric {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::BenfordMad => write!(f, "benford_mad"),
41            Self::BalanceCoherence => write!(f, "balance_coherence"),
42            Self::DocumentChainIntegrity => write!(f, "document_chain_integrity"),
43            Self::CorrelationPreservation => write!(f, "correlation_preservation"),
44            Self::TemporalConsistency => write!(f, "temporal_consistency"),
45            Self::PrivacyMiaAuc => write!(f, "privacy_mia_auc"),
46            Self::CompletionRate => write!(f, "completion_rate"),
47            Self::DuplicateRate => write!(f, "duplicate_rate"),
48            Self::ReferentialIntegrity => write!(f, "referential_integrity"),
49            Self::IcMatchRate => write!(f, "ic_match_rate"),
50            Self::Custom(name) => write!(f, "custom:{}", name),
51        }
52    }
53}
54
55/// Comparison operator for threshold checks.
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57#[serde(rename_all = "snake_case")]
58pub enum Comparison {
59    /// Greater than or equal to threshold.
60    Gte,
61    /// Less than or equal to threshold.
62    Lte,
63    /// Equal to threshold (with epsilon).
64    Eq,
65    /// Between two thresholds (inclusive). Uses `threshold` as lower and `upper_threshold` as upper.
66    Between,
67}
68
69/// Strategy for handling gate failures.
70#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
71#[serde(rename_all = "snake_case")]
72pub enum FailStrategy {
73    /// Stop checking on first failure.
74    FailFast,
75    /// Check all gates and collect all failures.
76    #[default]
77    CollectAll,
78}
79
80/// A single quality gate with a metric, threshold, and comparison.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct QualityGate {
83    /// Human-readable name for this gate.
84    pub name: String,
85    /// The metric to check.
86    pub metric: QualityMetric,
87    /// Threshold value for comparison.
88    pub threshold: f64,
89    /// Upper threshold for Between comparison.
90    #[serde(default, skip_serializing_if = "Option::is_none")]
91    pub upper_threshold: Option<f64>,
92    /// How to compare the metric value against the threshold.
93    pub comparison: Comparison,
94}
95
96impl QualityGate {
97    /// Create a new quality gate.
98    pub fn new(
99        name: impl Into<String>,
100        metric: QualityMetric,
101        threshold: f64,
102        comparison: Comparison,
103    ) -> Self {
104        Self {
105            name: name.into(),
106            metric,
107            threshold,
108            upper_threshold: None,
109            comparison,
110        }
111    }
112
113    /// Create a gate that requires metric >= threshold.
114    pub fn gte(name: impl Into<String>, metric: QualityMetric, threshold: f64) -> Self {
115        Self::new(name, metric, threshold, Comparison::Gte)
116    }
117
118    /// Create a gate that requires metric <= threshold.
119    pub fn lte(name: impl Into<String>, metric: QualityMetric, threshold: f64) -> Self {
120        Self::new(name, metric, threshold, Comparison::Lte)
121    }
122
123    /// Create a gate that requires metric between lower and upper (inclusive).
124    pub fn between(name: impl Into<String>, metric: QualityMetric, lower: f64, upper: f64) -> Self {
125        Self {
126            name: name.into(),
127            metric,
128            threshold: lower,
129            upper_threshold: Some(upper),
130            comparison: Comparison::Between,
131        }
132    }
133
134    /// Check if an actual value passes this gate.
135    pub fn check(&self, actual: f64) -> bool {
136        match self.comparison {
137            Comparison::Gte => actual >= self.threshold,
138            Comparison::Lte => actual <= self.threshold,
139            Comparison::Eq => (actual - self.threshold).abs() < 1e-9,
140            Comparison::Between => {
141                let upper = self.upper_threshold.unwrap_or(self.threshold);
142                actual >= self.threshold && actual <= upper
143            }
144        }
145    }
146}
147
/// A named collection of quality gates.
///
/// Gates are evaluated in the order they appear in `gates`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateProfile {
    /// Profile name (e.g., "strict", "default", "lenient"); echoed in results and summaries.
    pub name: String,
    /// List of quality gates in this profile.
    pub gates: Vec<QualityGate>,
    /// Strategy for handling failures. When omitted from serialized input this
    /// falls back to `FailStrategy::default()`, i.e. `CollectAll`.
    #[serde(default)]
    pub fail_strategy: FailStrategy,
}
159
160impl GateProfile {
161    /// Create a new gate profile.
162    pub fn new(name: impl Into<String>, gates: Vec<QualityGate>) -> Self {
163        Self {
164            name: name.into(),
165            gates,
166            fail_strategy: FailStrategy::default(),
167        }
168    }
169
170    /// Set the fail strategy.
171    pub fn with_fail_strategy(mut self, strategy: FailStrategy) -> Self {
172        self.fail_strategy = strategy;
173        self
174    }
175}
176
177/// Result of checking a single gate.
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct GateCheckResult {
180    /// Gate name.
181    pub gate_name: String,
182    /// Metric checked.
183    pub metric: QualityMetric,
184    /// Whether the gate passed.
185    pub passed: bool,
186    /// Actual metric value.
187    pub actual_value: Option<f64>,
188    /// Expected threshold.
189    pub threshold: f64,
190    /// Comparison used.
191    pub comparison: Comparison,
192    /// Human-readable message.
193    pub message: String,
194}
195
/// Overall result of evaluating all gates in a profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateResult {
    /// Whether every checked gate passed. Gates whose metric was unavailable
    /// count as passed.
    pub passed: bool,
    /// Profile name used.
    pub profile_name: String,
    /// Individual gate results, in profile order. Under `FailStrategy::FailFast`
    /// the list ends at the first failing gate.
    pub results: Vec<GateCheckResult>,
    /// Human-readable summary; lists the failed gate names when any failed.
    pub summary: String,
    /// Number of gates that passed.
    pub gates_passed: usize,
    /// Number of gates actually checked (equal to `results.len()`); may be
    /// fewer than the profile's gate count when evaluation stopped early.
    pub gates_total: usize,
}
212
/// Engine that evaluates quality gates against a comprehensive evaluation.
///
/// Stateless: all functionality is exposed through associated functions such
/// as `GateEngine::evaluate`, so the derives below are free.
#[derive(Debug, Clone, Copy, Default)]
pub struct GateEngine;
215
216impl GateEngine {
217    /// Evaluate a comprehensive evaluation against a gate profile.
218    pub fn evaluate(evaluation: &ComprehensiveEvaluation, profile: &GateProfile) -> GateResult {
219        let mut results = Vec::new();
220        let mut all_passed = true;
221
222        for gate in &profile.gates {
223            let (actual_value, message) = Self::extract_metric(evaluation, &gate.metric);
224
225            let check_result = match actual_value {
226                Some(value) => {
227                    let passed = gate.check(value);
228                    if !passed {
229                        all_passed = false;
230                    }
231                    GateCheckResult {
232                        gate_name: gate.name.clone(),
233                        metric: gate.metric.clone(),
234                        passed,
235                        actual_value: Some(value),
236                        threshold: gate.threshold,
237                        comparison: gate.comparison.clone(),
238                        message: if passed {
239                            format!(
240                                "{}: {:.4} passes {:?} {:.4}",
241                                gate.name, value, gate.comparison, gate.threshold
242                            )
243                        } else {
244                            format!(
245                                "{}: {:.4} fails {:?} {:.4}",
246                                gate.name, value, gate.comparison, gate.threshold
247                            )
248                        },
249                    }
250                }
251                None => {
252                    // Metric not available - treat as not applicable (pass)
253                    GateCheckResult {
254                        gate_name: gate.name.clone(),
255                        metric: gate.metric.clone(),
256                        passed: true,
257                        actual_value: None,
258                        threshold: gate.threshold,
259                        comparison: gate.comparison.clone(),
260                        message: format!("{}: metric not available ({})", gate.name, message),
261                    }
262                }
263            };
264
265            let failed = !check_result.passed;
266            results.push(check_result);
267
268            if failed && profile.fail_strategy == FailStrategy::FailFast {
269                break;
270            }
271        }
272
273        let gates_passed = results.iter().filter(|r| r.passed).count();
274        let gates_total = results.len();
275
276        let summary = if all_passed {
277            format!(
278                "All {}/{} quality gates passed (profile: {})",
279                gates_passed, gates_total, profile.name
280            )
281        } else {
282            let failed_names: Vec<_> = results
283                .iter()
284                .filter(|r| !r.passed)
285                .map(|r| r.gate_name.as_str())
286                .collect();
287            format!(
288                "{}/{} quality gates passed, {} failed: {} (profile: {})",
289                gates_passed,
290                gates_total,
291                gates_total - gates_passed,
292                failed_names.join(", "),
293                profile.name
294            )
295        };
296
297        GateResult {
298            passed: all_passed,
299            profile_name: profile.name.clone(),
300            results,
301            summary,
302            gates_passed,
303            gates_total,
304        }
305    }
306
307    /// Extract a metric value from a comprehensive evaluation.
308    fn extract_metric(
309        evaluation: &ComprehensiveEvaluation,
310        metric: &QualityMetric,
311    ) -> (Option<f64>, String) {
312        match metric {
313            QualityMetric::BenfordMad => {
314                let mad = evaluation.statistical.benford.as_ref().map(|b| b.mad);
315                (mad, "benford analysis not available".to_string())
316            }
317            QualityMetric::BalanceCoherence => {
318                let rate = evaluation.coherence.balance.as_ref().map(|b| {
319                    if b.equation_balanced {
320                        1.0
321                    } else {
322                        0.0
323                    }
324                });
325                (rate, "balance sheet evaluation not available".to_string())
326            }
327            QualityMetric::DocumentChainIntegrity => {
328                let rate = evaluation
329                    .coherence
330                    .document_chain
331                    .as_ref()
332                    .map(|d| d.p2p_completion_rate);
333                (rate, "document chain evaluation not available".to_string())
334            }
335            QualityMetric::CorrelationPreservation => {
336                // Not directly available in ComprehensiveEvaluation - return None
337                (
338                    None,
339                    "correlation preservation metric not available".to_string(),
340                )
341            }
342            QualityMetric::TemporalConsistency => {
343                let rate = evaluation
344                    .statistical
345                    .temporal
346                    .as_ref()
347                    .map(|t| t.pattern_correlation);
348                (rate, "temporal analysis not available".to_string())
349            }
350            QualityMetric::PrivacyMiaAuc => {
351                let auc = evaluation
352                    .privacy
353                    .as_ref()
354                    .and_then(|p| p.membership_inference.as_ref())
355                    .map(|m| m.auc_roc);
356                (auc, "privacy MIA evaluation not available".to_string())
357            }
358            QualityMetric::CompletionRate => {
359                let rate = evaluation
360                    .quality
361                    .completeness
362                    .as_ref()
363                    .map(|c| c.overall_completeness);
364                (rate, "completeness analysis not available".to_string())
365            }
366            QualityMetric::DuplicateRate => {
367                let rate = evaluation
368                    .quality
369                    .uniqueness
370                    .as_ref()
371                    .map(|u| u.duplicate_rate);
372                (rate, "uniqueness analysis not available".to_string())
373            }
374            QualityMetric::ReferentialIntegrity => {
375                let rate = evaluation
376                    .coherence
377                    .referential
378                    .as_ref()
379                    .map(|r| r.overall_integrity_score);
380                (
381                    rate,
382                    "referential integrity evaluation not available".to_string(),
383                )
384            }
385            QualityMetric::IcMatchRate => {
386                let rate = evaluation
387                    .coherence
388                    .intercompany
389                    .as_ref()
390                    .map(|ic| ic.match_rate);
391                (rate, "IC matching evaluation not available".to_string())
392            }
393            QualityMetric::Custom(name) => (
394                None,
395                format!(
396                    "custom metric '{}' not available in standard evaluation",
397                    name
398                ),
399            ),
400        }
401    }
402}
403
#[cfg(test)]
// Tests may unwrap freely: a panic here is a test failure, not a recoverable error.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two gates whose metrics an empty `ComprehensiveEvaluation` does not provide.
    fn sample_profile() -> GateProfile {
        GateProfile::new(
            "test",
            vec![
                QualityGate::lte("benford_compliance", QualityMetric::BenfordMad, 0.015),
                QualityGate::gte("completeness", QualityMetric::CompletionRate, 0.95),
            ],
        )
    }

    #[test]
    fn test_gate_check_gte() {
        let gate = QualityGate::gte("test", QualityMetric::CompletionRate, 0.95);
        assert!(gate.check(0.96));
        assert!(gate.check(0.95));
        assert!(!gate.check(0.94));
    }

    #[test]
    fn test_gate_check_lte() {
        let gate = QualityGate::lte("test", QualityMetric::BenfordMad, 0.015);
        assert!(gate.check(0.01));
        assert!(gate.check(0.015));
        assert!(!gate.check(0.016));
    }

    #[test]
    fn test_gate_check_between() {
        let gate = QualityGate::between("test", QualityMetric::DuplicateRate, 0.0, 0.05);
        assert!(gate.check(0.0));
        assert!(gate.check(0.03));
        assert!(gate.check(0.05));
        assert!(!gate.check(0.06));
        assert!(!gate.check(-0.01));
    }

    #[test]
    fn test_gate_check_between_without_upper_requires_exact_threshold() {
        let mut gate = QualityGate::between("test", QualityMetric::DuplicateRate, 0.02, 0.05);
        gate.upper_threshold = None;
        assert!(gate.check(0.02));
        assert!(!gate.check(0.03));
        assert!(!gate.check(0.01));
    }

    #[test]
    fn test_gate_check_eq() {
        let gate = QualityGate::new("test", QualityMetric::BalanceCoherence, 1.0, Comparison::Eq);
        assert!(gate.check(1.0));
        assert!(gate.check(1.0 + 1e-10));
        assert!(!gate.check(0.99));
    }

    #[test]
    fn test_gate_check_nan_fails_every_comparison() {
        let gates = [
            QualityGate::gte("gte", QualityMetric::CompletionRate, 0.95),
            QualityGate::lte("lte", QualityMetric::BenfordMad, 0.015),
            QualityGate::new("eq", QualityMetric::BalanceCoherence, 1.0, Comparison::Eq),
            QualityGate::between("between", QualityMetric::DuplicateRate, 0.0, 0.05),
        ];
        for gate in &gates {
            assert!(!gate.check(f64::NAN), "{} accepted NaN", gate.name);
        }
    }

    #[test]
    fn test_evaluate_empty_evaluation() {
        let evaluation = ComprehensiveEvaluation::new();
        let profile = sample_profile();
        let result = GateEngine::evaluate(&evaluation, &profile);
        // All metrics unavailable → treated as pass
        assert!(result.passed);
        assert_eq!(result.gates_total, 2);
        assert_eq!(result.gates_passed, 2);
    }

    #[test]
    fn test_unavailable_metrics_report_no_value() {
        let evaluation = ComprehensiveEvaluation::new();
        let result = GateEngine::evaluate(&evaluation, &sample_profile());
        for check in &result.results {
            assert!(check.passed);
            assert_eq!(check.actual_value, None);
            assert!(
                check.message.contains("metric not available"),
                "unexpected message: {}",
                check.message
            );
        }
    }

    #[test]
    fn test_fail_fast_checks_every_gate_when_none_fail() {
        let evaluation = ComprehensiveEvaluation::new();
        // Custom metrics are never extracted by the engine, so both gates report
        // "not available" and count as passes. FailFast therefore has no failure
        // to stop on, and every gate is checked.
        // NOTE(review): exercising the early stop itself needs an evaluation that
        // carries a failing metric value.
        let profile = GateProfile::new(
            "strict",
            vec![
                QualityGate::gte(
                    "custom_gate",
                    QualityMetric::Custom("nonexistent".to_string()),
                    0.99,
                ),
                QualityGate::gte(
                    "another",
                    QualityMetric::Custom("also_nonexistent".to_string()),
                    0.99,
                ),
            ],
        )
        .with_fail_strategy(FailStrategy::FailFast);

        let result = GateEngine::evaluate(&evaluation, &profile);
        assert!(result.passed);
        assert_eq!(result.results.len(), 2);
        assert_eq!(result.gates_total, 2);
        assert_eq!(result.gates_passed, 2);
    }

    #[test]
    fn test_collect_all_checks_every_gate() {
        let evaluation = ComprehensiveEvaluation::new();
        let profile = GateProfile::new(
            "test",
            vec![
                QualityGate::lte("mad", QualityMetric::BenfordMad, 0.015),
                QualityGate::gte("completion", QualityMetric::CompletionRate, 0.95),
            ],
        )
        .with_fail_strategy(FailStrategy::CollectAll);

        let result = GateEngine::evaluate(&evaluation, &profile);
        assert_eq!(result.results.len(), 2);
        assert_eq!(result.gates_total, 2);
    }

    #[test]
    fn test_gate_result_summary() {
        let evaluation = ComprehensiveEvaluation::new();
        let profile = sample_profile();
        let result = GateEngine::evaluate(&evaluation, &profile);
        assert!(result.summary.contains("test"));
        assert_eq!(
            result.summary,
            "All 2/2 quality gates passed (profile: test)"
        );
    }

    #[test]
    fn test_quality_metric_display() {
        assert_eq!(QualityMetric::BenfordMad.to_string(), "benford_mad");
        assert_eq!(
            QualityMetric::BalanceCoherence.to_string(),
            "balance_coherence"
        );
        assert_eq!(QualityMetric::IcMatchRate.to_string(), "ic_match_rate");
        assert_eq!(
            QualityMetric::Custom("my_metric".to_string()).to_string(),
            "custom:my_metric"
        );
    }

    #[test]
    fn test_gate_profile_serialization() {
        let profile = sample_profile();
        let json = serde_json::to_string(&profile).expect("serialize");
        // `upper_threshold` is skipped when None.
        assert!(!json.contains("upper_threshold"));
        let deserialized: GateProfile = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.name, "test");
        assert_eq!(deserialized.gates.len(), 2);
        assert_eq!(deserialized.fail_strategy, profile.fail_strategy);
    }

    #[test]
    fn test_enum_wire_format_is_snake_case() {
        assert_eq!(
            serde_json::to_string(&Comparison::Gte).expect("serialize"),
            "\"gte\""
        );
        assert_eq!(
            serde_json::to_string(&QualityMetric::IcMatchRate).expect("serialize"),
            "\"ic_match_rate\""
        );
        assert_eq!(
            serde_json::to_string(&QualityMetric::Custom("x".to_string())).expect("serialize"),
            r#"{"custom":"x"}"#
        );
        assert_eq!(
            serde_json::to_string(&FailStrategy::FailFast).expect("serialize"),
            "\"fail_fast\""
        );
    }

    #[test]
    fn test_profile_deserialization_defaults() {
        let json = r#"{"name":"p","gates":[{"name":"g","metric":"benford_mad","threshold":0.015,"comparison":"lte"}]}"#;
        let profile: GateProfile = serde_json::from_str(json).expect("deserialize");
        assert_eq!(profile.fail_strategy, FailStrategy::CollectAll);
        assert_eq!(profile.gates.len(), 1);
        assert_eq!(profile.gates[0].metric, QualityMetric::BenfordMad);
        assert_eq!(profile.gates[0].comparison, Comparison::Lte);
        assert_eq!(profile.gates[0].upper_threshold, None);
    }
}