Skip to main content

ai_lib_rust/guardrails/
result.rs

//! Check result types: violations and the aggregated result of a content check.
2
3use super::config::FilterAction;
4use serde::{Deserialize, Serialize};
5
6/// Type of violation detected
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8#[serde(rename_all = "snake_case")]
9pub enum ViolationType {
10    /// Keyword match
11    Keyword,
12    /// Regex pattern match
13    Pattern,
14    /// PII detection
15    Pii,
16    /// Custom rule
17    Custom,
18}
19
20/// A detected violation
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Violation {
23    /// Type of violation
24    pub violation_type: ViolationType,
25    /// The pattern or rule that matched
26    pub pattern: String,
27    /// Action associated with this violation
28    pub action: FilterAction,
29    /// Category of the violation
30    pub category: Option<String>,
31    /// Description of the violation
32    pub description: Option<String>,
33    /// The matched text (may be masked for sensitive data)
34    pub matched_text: Option<String>,
35}
36
37impl Violation {
38    /// Check if this violation should block the content
39    pub fn is_blocking(&self) -> bool {
40        matches!(self.action, FilterAction::Block)
41    }
42
43    /// Check if this violation is a warning
44    pub fn is_warning(&self) -> bool {
45        matches!(self.action, FilterAction::Warn)
46    }
47}
48
49/// Result of a content check
50#[derive(Debug, Clone, Default, Serialize, Deserialize)]
51pub struct CheckResult {
52    /// List of violations found
53    violations: Vec<Violation>,
54    /// Whether the content should be blocked
55    blocked: bool,
56    /// Whether warnings were generated
57    warned: bool,
58}
59
60impl CheckResult {
61    /// Create a passed result (no violations)
62    pub fn passed() -> Self {
63        Self {
64            violations: Vec::new(),
65            blocked: false,
66            warned: false,
67        }
68    }
69
70    /// Create a result from violations
71    pub fn from_violations(violations: Vec<Violation>) -> Self {
72        let blocked = violations.iter().any(|v| v.is_blocking());
73        let warned = violations.iter().any(|v| v.is_warning());
74        
75        Self {
76            violations,
77            blocked,
78            warned,
79        }
80    }
81
82    /// Check if the content passed all checks
83    pub fn is_passed(&self) -> bool {
84        !self.blocked && self.violations.is_empty()
85    }
86
87    /// Check if the content was blocked
88    pub fn is_blocked(&self) -> bool {
89        self.blocked
90    }
91
92    /// Check if warnings were generated
93    pub fn is_warned(&self) -> bool {
94        self.warned
95    }
96
97    /// Check if there are any violations
98    pub fn has_violations(&self) -> bool {
99        !self.violations.is_empty()
100    }
101
102    /// Get the list of violations
103    pub fn violations(&self) -> &[Violation] {
104        &self.violations
105    }
106
107    /// Get blocking violations only
108    pub fn blocking_violations(&self) -> Vec<&Violation> {
109        self.violations.iter().filter(|v| v.is_blocking()).collect()
110    }
111
112    /// Get warning violations only
113    pub fn warning_violations(&self) -> Vec<&Violation> {
114        self.violations.iter().filter(|v| v.is_warning()).collect()
115    }
116
117    /// Merge another result into this one
118    pub fn merge(mut self, other: CheckResult) -> Self {
119        self.violations.extend(other.violations);
120        self.blocked = self.blocked || other.blocked;
121        self.warned = self.warned || other.warned;
122        self
123    }
124
125    /// Get a summary string
126    pub fn summary(&self) -> String {
127        if self.is_passed() {
128            "PASSED".to_string()
129        } else if self.is_blocked() {
130            format!("BLOCKED: {} violation(s)", self.violations.len())
131        } else if self.is_warned() {
132            format!("WARNING: {} violation(s)", self.violations.len())
133        } else {
134            format!("INFO: {} item(s) logged", self.violations.len())
135        }
136    }
137}
138
139impl std::fmt::Display for CheckResult {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        write!(f, "{}", self.summary())
142    }
143}