aimds_analysis/
policy_verifier.rs

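//! Policy verification against LTL-based security policies
//!
//! Simplified implementation using aimds-core types: per-policy checking is
//! currently a stub that always passes; it is intended to use the temporal
//! neural solver in production.
//!
//! Performance target: <500ms p99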
6
7use aimds_core::types::PromptInput;
8use crate::errors::AnalysisResult;
9use std::sync::Arc;
10use std::collections::HashMap;
11
12/// Security policy with LTL formula
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct SecurityPolicy {
15    /// Policy identifier
16    pub id: String,
17    /// Human-readable description
18    pub description: String,
19    /// LTL formula for verification
20    pub formula: String,
21    /// Policy severity (0.0 = info, 1.0 = critical)
22    pub severity: f64,
23    /// Whether policy is enabled
24    pub enabled: bool,
25}
26
27impl SecurityPolicy {
28    /// Create new security policy
29    pub fn new(id: impl Into<String>, description: impl Into<String>, formula: impl Into<String>) -> Self {
30        Self {
31            id: id.into(),
32            description: description.into(),
33            formula: formula.into(),
34            severity: 0.5,
35            enabled: true,
36        }
37    }
38
39    /// Set policy severity
40    pub fn with_severity(mut self, severity: f64) -> Self {
41        self.severity = severity.clamp(0.0, 1.0);
42        self
43    }
44
45    /// Enable or disable policy
46    pub fn set_enabled(mut self, enabled: bool) -> Self {
47        self.enabled = enabled;
48        self
49    }
50}
51
52/// Policy verification result
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct VerificationResult {
55    /// Whether policy verification passed
56    pub verified: bool,
57    /// Confidence in verification result
58    pub confidence: f64,
59    /// List of policy violations (if any)
60    pub violations: Vec<String>,
61    /// Optional proof certificate
62    pub proof: Option<ProofCertificate>,
63}
64
65impl VerificationResult {
66    /// Create verified result
67    pub fn verified() -> Self {
68        Self {
69            verified: true,
70            confidence: 1.0,
71            violations: Vec::new(),
72            proof: None,
73        }
74    }
75
76    /// Create verification failure
77    pub fn failed(violations: Vec<String>) -> Self {
78        Self {
79            verified: false,
80            confidence: 1.0,
81            violations,
82            proof: None,
83        }
84    }
85
86    /// Add proof certificate
87    pub fn with_proof(mut self, proof: ProofCertificate) -> Self {
88        self.proof = Some(proof);
89        self
90    }
91}
92
93/// Proof certificate for verification
94#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
95pub struct ProofCertificate {
96    /// Proof type
97    pub proof_type: String,
98    /// Proof steps
99    pub steps: Vec<String>,
100    /// Verification timestamp
101    pub timestamp: u64,
102}
103
104/// Policy verifier
105pub struct PolicyVerifier {
106    policies: Arc<std::sync::RwLock<HashMap<String, SecurityPolicy>>>,
107}
108
109impl PolicyVerifier {
110    /// Create new policy verifier
111    pub fn new() -> AnalysisResult<Self> {
112        Ok(Self {
113            policies: Arc::new(std::sync::RwLock::new(HashMap::new())),
114        })
115    }
116
117    /// Verify action against all enabled policies
118    pub async fn verify_policy(&self, input: &PromptInput) -> AnalysisResult<VerificationResult> {
119        let policies = self.policies.read().unwrap();
120        let enabled_policies: Vec<_> = policies.values()
121            .filter(|p| p.enabled)
122            .cloned()
123            .collect();
124
125        drop(policies);
126
127        if enabled_policies.is_empty() {
128            return Ok(VerificationResult::verified());
129        }
130
131        // Simplified verification - checks for basic patterns
132        let mut violations = Vec::new();
133
134        for policy in enabled_policies {
135            if !self.check_policy(input, &policy) {
136                violations.push(policy.id.clone());
137            }
138        }
139
140        if violations.is_empty() {
141            Ok(VerificationResult::verified())
142        } else {
143            Ok(VerificationResult::failed(violations))
144        }
145    }
146
147    fn check_policy(&self, _input: &PromptInput, _policy: &SecurityPolicy) -> bool {
148        // Simplified stub - always passes
149        // In production, this would use temporal-neural-solver
150        true
151    }
152
153    /// Add security policy
154    pub fn add_policy(&mut self, policy: SecurityPolicy) {
155        let mut policies = self.policies.write().unwrap();
156        policies.insert(policy.id.clone(), policy);
157    }
158
159    /// Remove security policy
160    pub fn remove_policy(&mut self, id: &str) -> Option<SecurityPolicy> {
161        let mut policies = self.policies.write().unwrap();
162        policies.remove(id)
163    }
164
165    /// Get policy by ID
166    pub fn get_policy(&self, id: &str) -> Option<SecurityPolicy> {
167        let policies = self.policies.read().unwrap();
168        policies.get(id).cloned()
169    }
170
171    /// Enable policy
172    pub fn enable_policy(&mut self, id: &str) -> AnalysisResult<()> {
173        let mut policies = self.policies.write().unwrap();
174        if let Some(policy) = policies.get_mut(id) {
175            policy.enabled = true;
176        }
177        Ok(())
178    }
179
180    /// Disable policy
181    pub fn disable_policy(&mut self, id: &str) -> AnalysisResult<()> {
182        let mut policies = self.policies.write().unwrap();
183        if let Some(policy) = policies.get_mut(id) {
184            policy.enabled = false;
185        }
186        Ok(())
187    }
188
189    /// Get all policies
190    pub fn list_policies(&self) -> Vec<SecurityPolicy> {
191        let policies = self.policies.read().unwrap();
192        policies.values().cloned().collect()
193    }
194
195    /// Get number of policies
196    pub fn policy_count(&self) -> usize {
197        let policies = self.policies.read().unwrap();
198        policies.len()
199    }
200
201    /// Get number of enabled policies
202    pub fn enabled_count(&self) -> usize {
203        let policies = self.policies.read().unwrap();
204        policies.values().filter(|p| p.enabled).count()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a verifier that already holds `policy`.
    fn verifier_with(policy: SecurityPolicy) -> PolicyVerifier {
        let mut verifier = PolicyVerifier::new().unwrap();
        verifier.add_policy(policy);
        verifier
    }

    #[tokio::test]
    async fn test_verifier_creation() {
        let fresh = PolicyVerifier::new().unwrap();
        assert_eq!(fresh.policy_count(), 0, "a new verifier starts empty");
    }

    #[test]
    fn test_policy_creation() {
        let policy = SecurityPolicy::new(
            "auth_check",
            "Verify authentication",
            "G (action -> authenticated)",
        )
        .with_severity(0.9);

        assert!(policy.enabled, "policies start enabled");
        assert_eq!(policy.severity, 0.9);
        assert_eq!(policy.id, "auth_check");
    }

    #[test]
    fn test_add_remove_policy() {
        let mut verifier = verifier_with(SecurityPolicy::new("test", "Test policy", "G true"));
        assert_eq!(verifier.policy_count(), 1);

        assert!(verifier.remove_policy("test").is_some());
        assert_eq!(verifier.policy_count(), 0);
    }

    #[test]
    fn test_enable_disable_policy() {
        let mut verifier = verifier_with(SecurityPolicy::new("test", "Test", "G true"));
        assert_eq!(verifier.enabled_count(), 1);

        verifier.disable_policy("test").unwrap();
        assert_eq!(verifier.enabled_count(), 0);

        verifier.enable_policy("test").unwrap();
        assert_eq!(verifier.enabled_count(), 1);
    }

    #[test]
    fn test_verification_result_helpers() {
        let passed = VerificationResult::verified();
        assert!(passed.verified && passed.violations.is_empty());

        let rejected = VerificationResult::failed(vec![String::from("policy1")]);
        assert!(!rejected.verified);
        assert_eq!(rejected.violations.len(), 1);
    }
}
272}