aimds_analysis/
policy_verifier.rs1use aimds_core::types::PromptInput;
8use crate::errors::AnalysisResult;
9use std::sync::Arc;
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct SecurityPolicy {
15 pub id: String,
17 pub description: String,
19 pub formula: String,
21 pub severity: f64,
23 pub enabled: bool,
25}
26
27impl SecurityPolicy {
28 pub fn new(id: impl Into<String>, description: impl Into<String>, formula: impl Into<String>) -> Self {
30 Self {
31 id: id.into(),
32 description: description.into(),
33 formula: formula.into(),
34 severity: 0.5,
35 enabled: true,
36 }
37 }
38
39 pub fn with_severity(mut self, severity: f64) -> Self {
41 self.severity = severity.clamp(0.0, 1.0);
42 self
43 }
44
45 pub fn set_enabled(mut self, enabled: bool) -> Self {
47 self.enabled = enabled;
48 self
49 }
50}
51
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct VerificationResult {
55 pub verified: bool,
57 pub confidence: f64,
59 pub violations: Vec<String>,
61 pub proof: Option<ProofCertificate>,
63}
64
65impl VerificationResult {
66 pub fn verified() -> Self {
68 Self {
69 verified: true,
70 confidence: 1.0,
71 violations: Vec::new(),
72 proof: None,
73 }
74 }
75
76 pub fn failed(violations: Vec<String>) -> Self {
78 Self {
79 verified: false,
80 confidence: 1.0,
81 violations,
82 proof: None,
83 }
84 }
85
86 pub fn with_proof(mut self, proof: ProofCertificate) -> Self {
88 self.proof = Some(proof);
89 self
90 }
91}
92
93#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
95pub struct ProofCertificate {
96 pub proof_type: String,
98 pub steps: Vec<String>,
100 pub timestamp: u64,
102}
103
104pub struct PolicyVerifier {
106 policies: Arc<std::sync::RwLock<HashMap<String, SecurityPolicy>>>,
107}
108
109impl PolicyVerifier {
110 pub fn new() -> AnalysisResult<Self> {
112 Ok(Self {
113 policies: Arc::new(std::sync::RwLock::new(HashMap::new())),
114 })
115 }
116
117 pub async fn verify_policy(&self, input: &PromptInput) -> AnalysisResult<VerificationResult> {
119 let policies = self.policies.read().unwrap();
120 let enabled_policies: Vec<_> = policies.values()
121 .filter(|p| p.enabled)
122 .cloned()
123 .collect();
124
125 drop(policies);
126
127 if enabled_policies.is_empty() {
128 return Ok(VerificationResult::verified());
129 }
130
131 let mut violations = Vec::new();
133
134 for policy in enabled_policies {
135 if !self.check_policy(input, &policy) {
136 violations.push(policy.id.clone());
137 }
138 }
139
140 if violations.is_empty() {
141 Ok(VerificationResult::verified())
142 } else {
143 Ok(VerificationResult::failed(violations))
144 }
145 }
146
147 fn check_policy(&self, _input: &PromptInput, _policy: &SecurityPolicy) -> bool {
148 true
151 }
152
153 pub fn add_policy(&mut self, policy: SecurityPolicy) {
155 let mut policies = self.policies.write().unwrap();
156 policies.insert(policy.id.clone(), policy);
157 }
158
159 pub fn remove_policy(&mut self, id: &str) -> Option<SecurityPolicy> {
161 let mut policies = self.policies.write().unwrap();
162 policies.remove(id)
163 }
164
165 pub fn get_policy(&self, id: &str) -> Option<SecurityPolicy> {
167 let policies = self.policies.read().unwrap();
168 policies.get(id).cloned()
169 }
170
171 pub fn enable_policy(&mut self, id: &str) -> AnalysisResult<()> {
173 let mut policies = self.policies.write().unwrap();
174 if let Some(policy) = policies.get_mut(id) {
175 policy.enabled = true;
176 }
177 Ok(())
178 }
179
180 pub fn disable_policy(&mut self, id: &str) -> AnalysisResult<()> {
182 let mut policies = self.policies.write().unwrap();
183 if let Some(policy) = policies.get_mut(id) {
184 policy.enabled = false;
185 }
186 Ok(())
187 }
188
189 pub fn list_policies(&self) -> Vec<SecurityPolicy> {
191 let policies = self.policies.read().unwrap();
192 policies.values().cloned().collect()
193 }
194
195 pub fn policy_count(&self) -> usize {
197 let policies = self.policies.read().unwrap();
198 policies.len()
199 }
200
201 pub fn enabled_count(&self) -> usize {
203 let policies = self.policies.read().unwrap();
204 policies.values().filter(|p| p.enabled).count()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing here awaits, so a plain `#[test]` suffices; no async runtime is needed.
    #[test]
    fn test_verifier_creation() {
        let verifier = PolicyVerifier::new().unwrap();
        assert_eq!(verifier.policy_count(), 0);
        assert_eq!(verifier.enabled_count(), 0);
        assert!(verifier.list_policies().is_empty());
    }

    #[test]
    fn test_policy_creation() {
        let policy = SecurityPolicy::new(
            "auth_check",
            "Verify authentication",
            "G (action -> authenticated)",
        )
        .with_severity(0.9);

        assert_eq!(policy.id, "auth_check");
        assert_eq!(policy.description, "Verify authentication");
        assert_eq!(policy.formula, "G (action -> authenticated)");
        assert_eq!(policy.severity, 0.9);
        assert!(policy.enabled);
    }

    #[test]
    fn test_severity_is_clamped() {
        let high = SecurityPolicy::new("p", "d", "G true").with_severity(5.0);
        assert_eq!(high.severity, 1.0);

        let low = SecurityPolicy::new("p", "d", "G true").with_severity(-1.0);
        assert_eq!(low.severity, 0.0);

        let default = SecurityPolicy::new("p", "d", "G true");
        assert_eq!(default.severity, 0.5);
    }

    #[test]
    fn test_add_remove_policy() {
        let mut verifier = PolicyVerifier::new().unwrap();

        let policy = SecurityPolicy::new("test", "Test policy", "G true");
        verifier.add_policy(policy.clone());

        assert_eq!(verifier.policy_count(), 1);

        let removed = verifier.remove_policy("test");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().id, "test");
        assert_eq!(verifier.policy_count(), 0);

        // Removing again finds nothing.
        assert!(verifier.remove_policy("test").is_none());
    }

    #[test]
    fn test_add_replaces_same_id() {
        let mut verifier = PolicyVerifier::new().unwrap();
        verifier.add_policy(SecurityPolicy::new("dup", "first", "G true"));
        verifier.add_policy(SecurityPolicy::new("dup", "second", "G true"));

        assert_eq!(verifier.policy_count(), 1);
        assert_eq!(verifier.get_policy("dup").unwrap().description, "second");
    }

    #[test]
    fn test_get_policy() {
        let mut verifier = PolicyVerifier::new().unwrap();
        verifier.add_policy(SecurityPolicy::new("known", "Known", "G true").with_severity(0.7));

        let found = verifier.get_policy("known").unwrap();
        assert_eq!(found.id, "known");
        assert_eq!(found.severity, 0.7);
        assert!(verifier.get_policy("missing").is_none());
    }

    #[test]
    fn test_enable_disable_policy() {
        let mut verifier = PolicyVerifier::new().unwrap();

        let policy = SecurityPolicy::new("test", "Test", "G true");
        verifier.add_policy(policy);

        assert_eq!(verifier.enabled_count(), 1);

        verifier.disable_policy("test").unwrap();
        assert_eq!(verifier.enabled_count(), 0);
        // Disabled policies stay registered.
        assert_eq!(verifier.policy_count(), 1);
        assert!(!verifier.get_policy("test").unwrap().enabled);

        verifier.enable_policy("test").unwrap();
        assert_eq!(verifier.enabled_count(), 1);
    }

    #[test]
    fn test_toggle_unknown_policy_is_ok() {
        let mut verifier = PolicyVerifier::new().unwrap();

        assert!(verifier.enable_policy("missing").is_ok());
        assert!(verifier.disable_policy("missing").is_ok());
        assert_eq!(verifier.policy_count(), 0);
    }

    #[test]
    fn test_list_policies() {
        let mut verifier = PolicyVerifier::new().unwrap();
        verifier.add_policy(SecurityPolicy::new("a", "A", "G true"));
        verifier.add_policy(SecurityPolicy::new("b", "B", "G true").set_enabled(false));

        let listed = verifier.list_policies();
        assert_eq!(listed.len(), 2);
        assert!(listed.iter().any(|p| p.id == "a" && p.enabled));
        assert!(listed.iter().any(|p| p.id == "b" && !p.enabled));
        assert_eq!(verifier.enabled_count(), 1);
    }

    #[test]
    fn test_verification_result_helpers() {
        let verified = VerificationResult::verified();
        assert!(verified.verified);
        assert_eq!(verified.confidence, 1.0);
        assert!(verified.violations.is_empty());
        assert!(verified.proof.is_none());

        let failed = VerificationResult::failed(vec!["policy1".to_string()]);
        assert!(!failed.verified);
        assert_eq!(failed.violations.len(), 1);
        assert!(failed.proof.is_none());

        let proved = failed.with_proof(ProofCertificate {
            proof_type: "trace".to_string(),
            steps: vec!["step1".to_string()],
            timestamp: 42,
        });
        let proof = proved.proof.as_ref().unwrap();
        assert_eq!(proof.proof_type, "trace");
        assert_eq!(proof.steps.len(), 1);
        assert_eq!(proof.timestamp, 42);
        assert!(!proved.verified);
    }
}