Skip to main content

kya_validator/
policy_advanced.rs

1// Advanced Policy Engine with Rule Composition and Dynamic Loading
2
3use crate::types::{Manifest, PolicyContext};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7
8/// Comparison operators for rule conditions
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum CompareOp {
11    #[serde(rename = "eq")]
12    Equals,
13    #[serde(rename = "ne")]
14    NotEquals,
15    #[serde(rename = "gt")]
16    GreaterThan,
17    #[serde(rename = "gte")]
18    GreaterThanOrEqual,
19    #[serde(rename = "lt")]
20    LessThan,
21    #[serde(rename = "lte")]
22    LessThanOrEqual,
23    #[serde(rename = "in")]
24    In,
25    #[serde(rename = "nin")]
26    NotIn,
27    #[serde(rename = "contains")]
28    Contains,
29}
30
31/// Rule condition types
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(tag = "type")]
34pub enum RuleCondition {
35    #[serde(rename = "json_pointer")]
36    JsonPointer {
37        pointer: String,
38        operator: CompareOp,
39        value: Value,
40    },
41    #[serde(rename = "crypto")]
42    Crypto { key_id: String, algorithm: String },
43    #[serde(rename = "resource")]
44    Resource { url: String, check: ResourceCheck },
45    #[serde(rename = "composite")]
46    Composite {
47        operator: LogicalOperator,
48        rules: Vec<RuleCondition>,
49    },
50}
51
52/// Logical operators for rule composition
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum LogicalOperator {
55    #[serde(rename = "and")]
56    And,
57    #[serde(rename = "or")]
58    Or,
59    #[serde(rename = "not")]
60    Not,
61}
62
63/// Resource check types
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub enum ResourceCheck {
66    #[serde(rename = "reachable")]
67    Reachable,
68    #[serde(rename = "content_contains")]
69    ContentContains { text: String },
70}
71
72/// Policy action when rule triggers
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub enum PolicyAction {
75    #[serde(rename = "allow")]
76    Allow,
77    #[serde(rename = "deny")]
78    Deny { reason: String },
79    #[serde(rename = "warn")]
80    Warn { message: String },
81    #[serde(rename = "log")]
82    Log { level: String },
83}
84
85/// Advanced policy rule
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PolicyRule {
88    pub name: String,
89    pub description: Option<String>,
90    pub condition: RuleCondition,
91    pub action: PolicyAction,
92    pub enabled: bool,
93    pub priority: i32,
94}
95
96/// Policy with multiple rules
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Policy {
99    pub id: String,
100    pub name: String,
101    pub version: String,
102    pub description: Option<String>,
103    pub rules: Vec<PolicyRule>,
104    pub metadata: Option<HashMap<String, Value>>,
105}
106
107/// Policy engine for evaluating policies
108#[derive(Debug, Clone)]
109pub struct PolicyEngine {
110    policies: Vec<Policy>,
111}
112
113impl PolicyEngine {
114    /// Create a new policy engine
115    pub fn new() -> Self {
116        Self {
117            policies: Vec::new(),
118        }
119    }
120
121    /// Add a policy to engine
122    pub fn add_policy(&mut self, policy: Policy) {
123        self.policies.push(policy);
124    }
125
126    /// Evaluate all policies against manifest and context
127    pub fn evaluate(
128        &self,
129        manifest: &Manifest,
130        _context: &PolicyContext,
131    ) -> PolicyEvaluationResult {
132        let mut results = Vec::new();
133        let mut denied = false;
134        let mut warnings = Vec::new();
135
136        // Sort rules by priority (higher priority first)
137        let mut all_rules: Vec<_> = self
138            .policies
139            .iter()
140            .flat_map(|p| p.rules.iter())
141            .filter(|r| r.enabled)
142            .collect();
143
144        all_rules.sort_by(|a, b| b.priority.cmp(&a.priority));
145
146        for rule in all_rules {
147            let rule_result = self.evaluate_rule(rule, manifest);
148
149            let result = PolicyRuleResult {
150                rule_name: rule.name.clone(),
151                matched: rule_result.matched,
152                action: rule.action.clone(),
153                errors: rule_result.errors,
154            };
155
156            // Process action
157            match &rule.action {
158                PolicyAction::Deny { reason: _ } if rule_result.matched => {
159                    denied = true;
160                    results.push(result);
161                    break; // Stop on deny
162                }
163                PolicyAction::Warn { message } if rule_result.matched => {
164                    warnings.push(message.clone());
165                    results.push(result);
166                }
167                _ => {
168                    if rule_result.matched {
169                        results.push(result);
170                    }
171                }
172            }
173        }
174
175        PolicyEvaluationResult {
176            allowed: !denied,
177            warnings,
178            rule_results: results,
179        }
180    }
181
182    /// Evaluate a single rule
183    fn evaluate_rule(&self, rule: &PolicyRule, manifest: &Manifest) -> RuleEvaluationResult {
184        let mut errors = Vec::new();
185
186        let matched = match self.evaluate_condition(&rule.condition, manifest) {
187            Ok(m) => m,
188            Err(e) => {
189                errors.push(e);
190                false
191            }
192        };
193
194        RuleEvaluationResult { matched, errors }
195    }
196
197    /// Evaluate a rule condition
198    fn evaluate_condition(
199        &self,
200        condition: &RuleCondition,
201        manifest: &Manifest,
202    ) -> Result<bool, String> {
203        match condition {
204            RuleCondition::JsonPointer {
205                pointer,
206                operator,
207                value,
208            } => self.evaluate_json_pointer(manifest, pointer, *operator, value),
209            RuleCondition::Crypto { .. } => {
210                // Simplified crypto check for now
211                Ok(false)
212            }
213            RuleCondition::Resource { url, check } => self.evaluate_resource(url, check),
214            RuleCondition::Composite { operator, rules } => {
215                self.evaluate_composite(operator, rules, manifest)
216            }
217        }
218    }
219
220    /// Evaluate JSON pointer condition
221    fn evaluate_json_pointer(
222        &self,
223        manifest: &Manifest,
224        pointer: &str,
225        operator: CompareOp,
226        expected: &Value,
227    ) -> Result<bool, String> {
228        // Convert manifest to JSON Value
229        let manifest_json = serde_json::to_value(manifest)
230            .map_err(|e| format!("Failed to serialize manifest: {}", e))?;
231
232        // Try to navigate to the pointer manually for now
233        // Note: A proper JSON pointer library would be better
234        let actual = match pointer {
235            "/kyaVersion" => manifest_json
236                .get("kyaVersion")
237                .cloned()
238                .ok_or_else(|| "kyaVersion not found".to_string())?,
239            "/agentId" => manifest_json
240                .get("agentId")
241                .cloned()
242                .ok_or_else(|| "agentId not found".to_string())?,
243            _ => {
244                return Err(format!("JSON pointer '{}' not yet supported", pointer));
245            }
246        };
247
248        // Compare based on operator
249        match operator {
250            CompareOp::Equals => Ok(actual == *expected),
251            CompareOp::NotEquals => Ok(actual != *expected),
252            CompareOp::GreaterThan => {
253                let a = actual.as_f64().ok_or("Value is not a number")?;
254                let b = expected.as_f64().ok_or("Expected value is not a number")?;
255                Ok(a > b)
256            }
257            CompareOp::GreaterThanOrEqual => {
258                let a = actual.as_f64().ok_or("Value is not a number")?;
259                let b = expected.as_f64().ok_or("Expected value is not a number")?;
260                Ok(a >= b)
261            }
262            CompareOp::LessThan => {
263                let a = actual.as_f64().ok_or("Value is not a number")?;
264                let b = expected.as_f64().ok_or("Expected value is not a number")?;
265                Ok(a < b)
266            }
267            CompareOp::LessThanOrEqual => {
268                let a = actual.as_f64().ok_or("Value is not a number")?;
269                let b = expected.as_f64().ok_or("Expected value is not a number")?;
270                Ok(a <= b)
271            }
272            CompareOp::In => {
273                let arr = expected
274                    .as_array()
275                    .ok_or("Expected value is not an array")?;
276                Ok(arr.contains(&actual))
277            }
278            CompareOp::NotIn => {
279                let arr = expected
280                    .as_array()
281                    .ok_or("Expected value is not an array")?;
282                Ok(!arr.contains(&actual))
283            }
284            CompareOp::Contains => {
285                let str_actual = actual.as_str().ok_or("Value is not a string")?;
286                let str_expected = expected.as_str().ok_or("Expected value is not a string")?;
287                Ok(str_actual.contains(str_expected))
288            }
289        }
290    }
291
292    /// Evaluate resource condition
293    fn evaluate_resource(&self, _url: &str, check: &ResourceCheck) -> Result<bool, String> {
294        match check {
295            ResourceCheck::Reachable => {
296                // Check if URL is reachable (simplified, returns true)
297                Ok(true)
298            }
299            ResourceCheck::ContentContains { text: _ } => {
300                // Check content (simplified, returns true)
301                Ok(true)
302            }
303        }
304    }
305
306    /// Evaluate composite condition
307    fn evaluate_composite(
308        &self,
309        operator: &LogicalOperator,
310        rules: &[RuleCondition],
311        manifest: &Manifest,
312    ) -> Result<bool, String> {
313        match operator {
314            LogicalOperator::And => {
315                for rule in rules {
316                    if !self.evaluate_condition(rule, manifest)? {
317                        return Ok(false);
318                    }
319                }
320                Ok(true)
321            }
322            LogicalOperator::Or => {
323                for rule in rules {
324                    if self.evaluate_condition(rule, manifest)? {
325                        return Ok(true);
326                    }
327                }
328                Ok(false)
329            }
330            LogicalOperator::Not => {
331                if rules.len() != 1 {
332                    return Err("NOT operator requires exactly one rule".to_string());
333                }
334                Ok(!self.evaluate_condition(&rules[0], manifest)?)
335            }
336        }
337    }
338}
339
340/// Result of policy evaluation
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct PolicyEvaluationResult {
343    pub allowed: bool,
344    pub warnings: Vec<String>,
345    pub rule_results: Vec<PolicyRuleResult>,
346}
347
348/// Result of individual rule evaluation
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct PolicyRuleResult {
351    pub rule_name: String,
352    pub matched: bool,
353    pub action: PolicyAction,
354    pub errors: Vec<String>,
355}
356
/// Internal outcome of evaluating one rule's condition.
struct RuleEvaluationResult {
    /// Whether the condition matched; `false` whenever evaluation failed.
    matched: bool,
    /// Evaluation failures; empty on success.
    errors: Vec<String>,
}
362
363impl Default for PolicyEngine {
364    fn default() -> Self {
365        Self::new()
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Manifest;
    use serde_json::json;

    /// Minimal manifest shared by the tests below.
    fn sample_manifest() -> Manifest {
        Manifest {
            kya_version: "1.0".to_string(),
            agent_id: "did:key:z6Mk...".to_string(),
            verification_method: None,
            proof: vec![],
            max_transaction_value: None,
            permitted_regions: None,
            forbidden_regions: None,
        }
    }

    /// Build a JSON pointer condition.
    fn pointer_condition(pointer: &str, operator: CompareOp, value: Value) -> RuleCondition {
        RuleCondition::JsonPointer {
            pointer: pointer.to_string(),
            operator,
            value,
        }
    }

    /// Build a composite condition.
    fn composite(operator: LogicalOperator, rules: Vec<RuleCondition>) -> RuleCondition {
        RuleCondition::Composite { operator, rules }
    }

    /// Build an enabled, priority-0 rule.
    fn rule(name: &str, condition: RuleCondition, action: PolicyAction) -> PolicyRule {
        PolicyRule {
            name: name.to_string(),
            description: None,
            condition,
            action,
            enabled: true,
            priority: 0,
        }
    }

    #[test]
    fn test_json_pointer_equals() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();

        let result = engine.evaluate_json_pointer(
            &manifest,
            "/kyaVersion",
            CompareOp::Equals,
            &json!("1.0"),
        );
        assert_eq!(result, Ok(true));

        let result = engine.evaluate_json_pointer(
            &manifest,
            "/kyaVersion",
            CompareOp::NotEquals,
            &json!("1.0"),
        );
        assert_eq!(result, Ok(false));
    }

    #[test]
    fn test_json_pointer_membership_and_contains() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();
        let versions = json!(["0.9", "1.0"]);

        assert_eq!(
            engine.evaluate_json_pointer(&manifest, "/kyaVersion", CompareOp::In, &versions),
            Ok(true)
        );
        assert_eq!(
            engine.evaluate_json_pointer(&manifest, "/kyaVersion", CompareOp::NotIn, &versions),
            Ok(false)
        );
        assert_eq!(
            engine.evaluate_json_pointer(
                &manifest,
                "/agentId",
                CompareOp::Contains,
                &json!("did:key:")
            ),
            Ok(true)
        );
    }

    #[test]
    fn test_json_pointer_errors() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();

        // "1.0" is a JSON string, so ordering comparisons are rejected, not coerced.
        assert!(engine
            .evaluate_json_pointer(&manifest, "/kyaVersion", CompareOp::GreaterThan, &json!(1))
            .is_err());
        // `in` requires an array on the expected side.
        assert!(engine
            .evaluate_json_pointer(&manifest, "/kyaVersion", CompareOp::In, &json!("1.0"))
            .is_err());
        // An unknown pointer is an error, not a silent mismatch.
        assert!(engine
            .evaluate_json_pointer(&manifest, "/doesNotExist", CompareOp::Equals, &json!(1))
            .is_err());
    }

    #[test]
    fn test_composite_operators() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();
        let yes = pointer_condition("/kyaVersion", CompareOp::Equals, json!("1.0"));
        let no = pointer_condition("/kyaVersion", CompareOp::Equals, json!("2.0"));
        let broken = pointer_condition("/doesNotExist", CompareOp::Equals, json!(1));

        let eval = |condition: RuleCondition| engine.evaluate_condition(&condition, &manifest);

        assert_eq!(
            eval(composite(LogicalOperator::And, vec![yes.clone(), no.clone()])),
            Ok(false)
        );
        assert_eq!(eval(composite(LogicalOperator::And, vec![])), Ok(true));
        assert!(eval(composite(LogicalOperator::And, vec![yes.clone(), broken])).is_err());
        assert_eq!(
            eval(composite(LogicalOperator::Or, vec![no.clone(), yes.clone()])),
            Ok(true)
        );
        assert_eq!(eval(composite(LogicalOperator::Or, vec![])), Ok(false));
        assert_eq!(eval(composite(LogicalOperator::Not, vec![no.clone()])), Ok(true));
        assert!(eval(composite(LogicalOperator::Not, vec![yes, no])).is_err());
    }

    #[test]
    fn test_condition_deserialization() {
        let condition: RuleCondition = serde_json::from_value(json!({
            "type": "composite",
            "operator": "and",
            "rules": [
                { "type": "json_pointer", "pointer": "/kyaVersion", "operator": "eq", "value": "1.0" },
                { "type": "json_pointer", "pointer": "/agentId", "operator": "contains", "value": "did:key:" }
            ]
        }))
        .expect("condition should deserialize");

        let engine = PolicyEngine::new();
        assert_eq!(
            engine.evaluate_condition(&condition, &sample_manifest()),
            Ok(true)
        );
    }

    #[test]
    fn test_rule_errors_are_captured() {
        let broken = rule(
            "broken",
            pointer_condition("/doesNotExist", CompareOp::Equals, json!(1)),
            PolicyAction::Deny {
                reason: "never".to_string(),
            },
        );

        let outcome = PolicyEngine::new().evaluate_rule(&broken, &sample_manifest());
        assert!(!outcome.matched);
        assert_eq!(outcome.errors.len(), 1);
    }

    #[test]
    fn test_policy_serialization() {
        let policy = Policy {
            id: "test-policy".to_string(),
            name: "Test Policy".to_string(),
            version: "1.0.0".to_string(),
            description: Some("A test policy".to_string()),
            rules: vec![PolicyRule {
                priority: 10,
                ..rule(
                    "deny-other-versions",
                    pointer_condition("/kyaVersion", CompareOp::NotEquals, json!("1.0")),
                    PolicyAction::Deny {
                        reason: "unsupported version".to_string(),
                    },
                )
            }],
            metadata: None,
        };

        let encoded = serde_json::to_string(&policy).expect("policy should serialize");
        let decoded: Policy = serde_json::from_str(&encoded).expect("policy should round-trip");

        assert_eq!(decoded.id, "test-policy");
        assert_eq!(decoded.description.as_deref(), Some("A test policy"));
        assert_eq!(decoded.rules.len(), 1);
        assert_eq!(decoded.rules[0].priority, 10);
        assert!(matches!(
            &decoded.rules[0].action,
            PolicyAction::Deny { reason } if reason.as_str() == "unsupported version"
        ));
        assert!(matches!(
            &decoded.rules[0].condition,
            RuleCondition::JsonPointer {
                operator: CompareOp::NotEquals,
                ..
            }
        ));
    }
}