1use crate::types::{Manifest, PolicyContext};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum CompareOp {
11 #[serde(rename = "eq")]
12 Equals,
13 #[serde(rename = "ne")]
14 NotEquals,
15 #[serde(rename = "gt")]
16 GreaterThan,
17 #[serde(rename = "gte")]
18 GreaterThanOrEqual,
19 #[serde(rename = "lt")]
20 LessThan,
21 #[serde(rename = "lte")]
22 LessThanOrEqual,
23 #[serde(rename = "in")]
24 In,
25 #[serde(rename = "nin")]
26 NotIn,
27 #[serde(rename = "contains")]
28 Contains,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(tag = "type")]
34pub enum RuleCondition {
35 #[serde(rename = "json_pointer")]
36 JsonPointer {
37 pointer: String,
38 operator: CompareOp,
39 value: Value,
40 },
41 #[serde(rename = "crypto")]
42 Crypto { key_id: String, algorithm: String },
43 #[serde(rename = "resource")]
44 Resource { url: String, check: ResourceCheck },
45 #[serde(rename = "composite")]
46 Composite {
47 operator: LogicalOperator,
48 rules: Vec<RuleCondition>,
49 },
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum LogicalOperator {
55 #[serde(rename = "and")]
56 And,
57 #[serde(rename = "or")]
58 Or,
59 #[serde(rename = "not")]
60 Not,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub enum ResourceCheck {
66 #[serde(rename = "reachable")]
67 Reachable,
68 #[serde(rename = "content_contains")]
69 ContentContains { text: String },
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub enum PolicyAction {
75 #[serde(rename = "allow")]
76 Allow,
77 #[serde(rename = "deny")]
78 Deny { reason: String },
79 #[serde(rename = "warn")]
80 Warn { message: String },
81 #[serde(rename = "log")]
82 Log { level: String },
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PolicyRule {
88 pub name: String,
89 pub description: Option<String>,
90 pub condition: RuleCondition,
91 pub action: PolicyAction,
92 pub enabled: bool,
93 pub priority: i32,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Policy {
99 pub id: String,
100 pub name: String,
101 pub version: String,
102 pub description: Option<String>,
103 pub rules: Vec<PolicyRule>,
104 pub metadata: Option<HashMap<String, Value>>,
105}
106
107#[derive(Debug, Clone)]
109pub struct PolicyEngine {
110 policies: Vec<Policy>,
111}
112
113impl PolicyEngine {
114 pub fn new() -> Self {
116 Self {
117 policies: Vec::new(),
118 }
119 }
120
121 pub fn add_policy(&mut self, policy: Policy) {
123 self.policies.push(policy);
124 }
125
126 pub fn evaluate(
128 &self,
129 manifest: &Manifest,
130 _context: &PolicyContext,
131 ) -> PolicyEvaluationResult {
132 let mut results = Vec::new();
133 let mut denied = false;
134 let mut warnings = Vec::new();
135
136 let mut all_rules: Vec<_> = self
138 .policies
139 .iter()
140 .flat_map(|p| p.rules.iter())
141 .filter(|r| r.enabled)
142 .collect();
143
144 all_rules.sort_by(|a, b| b.priority.cmp(&a.priority));
145
146 for rule in all_rules {
147 let rule_result = self.evaluate_rule(rule, manifest);
148
149 let result = PolicyRuleResult {
150 rule_name: rule.name.clone(),
151 matched: rule_result.matched,
152 action: rule.action.clone(),
153 errors: rule_result.errors,
154 };
155
156 match &rule.action {
158 PolicyAction::Deny { reason: _ } if rule_result.matched => {
159 denied = true;
160 results.push(result);
161 break; }
163 PolicyAction::Warn { message } if rule_result.matched => {
164 warnings.push(message.clone());
165 results.push(result);
166 }
167 _ => {
168 if rule_result.matched {
169 results.push(result);
170 }
171 }
172 }
173 }
174
175 PolicyEvaluationResult {
176 allowed: !denied,
177 warnings,
178 rule_results: results,
179 }
180 }
181
182 fn evaluate_rule(&self, rule: &PolicyRule, manifest: &Manifest) -> RuleEvaluationResult {
184 let mut errors = Vec::new();
185
186 let matched = match self.evaluate_condition(&rule.condition, manifest) {
187 Ok(m) => m,
188 Err(e) => {
189 errors.push(e);
190 false
191 }
192 };
193
194 RuleEvaluationResult { matched, errors }
195 }
196
197 fn evaluate_condition(
199 &self,
200 condition: &RuleCondition,
201 manifest: &Manifest,
202 ) -> Result<bool, String> {
203 match condition {
204 RuleCondition::JsonPointer {
205 pointer,
206 operator,
207 value,
208 } => self.evaluate_json_pointer(manifest, pointer, *operator, value),
209 RuleCondition::Crypto { .. } => {
210 Ok(false)
212 }
213 RuleCondition::Resource { url, check } => self.evaluate_resource(url, check),
214 RuleCondition::Composite { operator, rules } => {
215 self.evaluate_composite(operator, rules, manifest)
216 }
217 }
218 }
219
220 fn evaluate_json_pointer(
222 &self,
223 manifest: &Manifest,
224 pointer: &str,
225 operator: CompareOp,
226 expected: &Value,
227 ) -> Result<bool, String> {
228 let manifest_json = serde_json::to_value(manifest)
230 .map_err(|e| format!("Failed to serialize manifest: {}", e))?;
231
232 let actual = match pointer {
235 "/kyaVersion" => manifest_json
236 .get("kyaVersion")
237 .cloned()
238 .ok_or_else(|| "kyaVersion not found".to_string())?,
239 "/agentId" => manifest_json
240 .get("agentId")
241 .cloned()
242 .ok_or_else(|| "agentId not found".to_string())?,
243 _ => {
244 return Err(format!("JSON pointer '{}' not yet supported", pointer));
245 }
246 };
247
248 match operator {
250 CompareOp::Equals => Ok(actual == *expected),
251 CompareOp::NotEquals => Ok(actual != *expected),
252 CompareOp::GreaterThan => {
253 let a = actual.as_f64().ok_or("Value is not a number")?;
254 let b = expected.as_f64().ok_or("Expected value is not a number")?;
255 Ok(a > b)
256 }
257 CompareOp::GreaterThanOrEqual => {
258 let a = actual.as_f64().ok_or("Value is not a number")?;
259 let b = expected.as_f64().ok_or("Expected value is not a number")?;
260 Ok(a >= b)
261 }
262 CompareOp::LessThan => {
263 let a = actual.as_f64().ok_or("Value is not a number")?;
264 let b = expected.as_f64().ok_or("Expected value is not a number")?;
265 Ok(a < b)
266 }
267 CompareOp::LessThanOrEqual => {
268 let a = actual.as_f64().ok_or("Value is not a number")?;
269 let b = expected.as_f64().ok_or("Expected value is not a number")?;
270 Ok(a <= b)
271 }
272 CompareOp::In => {
273 let arr = expected
274 .as_array()
275 .ok_or("Expected value is not an array")?;
276 Ok(arr.contains(&actual))
277 }
278 CompareOp::NotIn => {
279 let arr = expected
280 .as_array()
281 .ok_or("Expected value is not an array")?;
282 Ok(!arr.contains(&actual))
283 }
284 CompareOp::Contains => {
285 let str_actual = actual.as_str().ok_or("Value is not a string")?;
286 let str_expected = expected.as_str().ok_or("Expected value is not a string")?;
287 Ok(str_actual.contains(str_expected))
288 }
289 }
290 }
291
292 fn evaluate_resource(&self, _url: &str, check: &ResourceCheck) -> Result<bool, String> {
294 match check {
295 ResourceCheck::Reachable => {
296 Ok(true)
298 }
299 ResourceCheck::ContentContains { text: _ } => {
300 Ok(true)
302 }
303 }
304 }
305
306 fn evaluate_composite(
308 &self,
309 operator: &LogicalOperator,
310 rules: &[RuleCondition],
311 manifest: &Manifest,
312 ) -> Result<bool, String> {
313 match operator {
314 LogicalOperator::And => {
315 for rule in rules {
316 if !self.evaluate_condition(rule, manifest)? {
317 return Ok(false);
318 }
319 }
320 Ok(true)
321 }
322 LogicalOperator::Or => {
323 for rule in rules {
324 if self.evaluate_condition(rule, manifest)? {
325 return Ok(true);
326 }
327 }
328 Ok(false)
329 }
330 LogicalOperator::Not => {
331 if rules.len() != 1 {
332 return Err("NOT operator requires exactly one rule".to_string());
333 }
334 Ok(!self.evaluate_condition(&rules[0], manifest)?)
335 }
336 }
337 }
338}
339
340#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct PolicyEvaluationResult {
343 pub allowed: bool,
344 pub warnings: Vec<String>,
345 pub rule_results: Vec<PolicyRuleResult>,
346}
347
348#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct PolicyRuleResult {
351 pub rule_name: String,
352 pub matched: bool,
353 pub action: PolicyAction,
354 pub errors: Vec<String>,
355}
356
/// Internal result of evaluating one rule's condition.
struct RuleEvaluationResult {
    /// Whether the condition held; `false` when evaluation failed.
    matched: bool,
    /// Evaluation failures, if any.
    errors: Vec<String>,
}
362
363impl Default for PolicyEngine {
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Manifest;

    /// Minimal manifest shared by the tests below.
    fn sample_manifest() -> Manifest {
        Manifest {
            kya_version: "1.0".to_string(),
            agent_id: "did:key:z6Mk...".to_string(),
            verification_method: None,
            proof: vec![],
            max_transaction_value: None,
            permitted_regions: None,
            forbidden_regions: None,
        }
    }

    #[test]
    fn test_json_pointer_equals() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();

        let result = engine.evaluate_json_pointer(
            &manifest,
            "/kyaVersion",
            CompareOp::Equals,
            &serde_json::json!("1.0"),
        );
        assert_eq!(result, Ok(true));
    }

    #[test]
    fn test_json_pointer_operators() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();
        let eval = |pointer: &str, op: CompareOp, expected: serde_json::Value| {
            engine.evaluate_json_pointer(&manifest, pointer, op, &expected)
        };

        assert_eq!(
            eval("/agentId", CompareOp::NotEquals, serde_json::json!("did:key:other")),
            Ok(true)
        );
        assert_eq!(
            eval("/agentId", CompareOp::Contains, serde_json::json!("did:key:")),
            Ok(true)
        );
        assert_eq!(
            eval("/kyaVersion", CompareOp::In, serde_json::json!(["0.9", "1.0"])),
            Ok(true)
        );
        assert_eq!(
            eval("/kyaVersion", CompareOp::NotIn, serde_json::json!(["0.9", "1.0"])),
            Ok(false)
        );
    }

    #[test]
    fn test_json_pointer_type_mismatch_is_error() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();
        let eval = |pointer: &str, op: CompareOp, expected: serde_json::Value| {
            engine.evaluate_json_pointer(&manifest, pointer, op, &expected)
        };

        assert_eq!(
            eval("/kyaVersion", CompareOp::GreaterThan, serde_json::json!(1)),
            Err("Value is not a number".to_string())
        );
        assert_eq!(
            eval("/kyaVersion", CompareOp::In, serde_json::json!("1.0")),
            Err("Expected value is not an array".to_string())
        );
    }

    #[test]
    fn test_composite_not() {
        let engine = PolicyEngine::new();
        let manifest = sample_manifest();
        let mismatch = RuleCondition::JsonPointer {
            pointer: "/kyaVersion".to_string(),
            operator: CompareOp::Equals,
            value: serde_json::json!("2.0"),
        };

        let negated = RuleCondition::Composite {
            operator: LogicalOperator::Not,
            rules: vec![mismatch.clone()],
        };
        assert_eq!(engine.evaluate_condition(&negated, &manifest), Ok(true));

        let malformed = RuleCondition::Composite {
            operator: LogicalOperator::Not,
            rules: vec![mismatch.clone(), mismatch],
        };
        assert_eq!(
            engine.evaluate_condition(&malformed, &manifest),
            Err("NOT operator requires exactly one rule".to_string())
        );
    }

    #[test]
    fn test_wire_format_names() {
        assert_eq!(
            serde_json::to_value(CompareOp::GreaterThanOrEqual).unwrap(),
            serde_json::json!("gte")
        );
        assert_eq!(
            serde_json::to_value(CompareOp::NotIn).unwrap(),
            serde_json::json!("nin")
        );
        assert_eq!(
            serde_json::to_value(LogicalOperator::And).unwrap(),
            serde_json::json!("and")
        );

        let condition: RuleCondition = serde_json::from_value(serde_json::json!({
            "type": "composite",
            "operator": "or",
            "rules": [
                { "type": "json_pointer", "pointer": "/agentId", "operator": "eq", "value": "x" },
                { "type": "resource", "url": "https://example.com", "check": "reachable" }
            ]
        }))
        .expect("condition deserializes");

        match condition {
            RuleCondition::Composite { operator, rules } => {
                assert_eq!(operator, LogicalOperator::Or);
                assert_eq!(rules.len(), 2);
                assert!(matches!(
                    rules[0],
                    RuleCondition::JsonPointer {
                        operator: CompareOp::Equals,
                        ..
                    }
                ));
                assert!(matches!(
                    rules[1],
                    RuleCondition::Resource {
                        check: ResourceCheck::Reachable,
                        ..
                    }
                ));
            }
            other => panic!("unexpected condition: {:?}", other),
        }
    }

    #[test]
    fn test_policy_serialization() {
        let policy = Policy {
            id: "test-policy".to_string(),
            name: "Test Policy".to_string(),
            version: "1.0.0".to_string(),
            description: Some("A test policy".to_string()),
            rules: vec![PolicyRule {
                name: "deny-old-versions".to_string(),
                description: None,
                condition: RuleCondition::JsonPointer {
                    pointer: "/kyaVersion".to_string(),
                    operator: CompareOp::NotEquals,
                    value: serde_json::json!("1.0"),
                },
                action: PolicyAction::Deny {
                    reason: "unsupported version".to_string(),
                },
                enabled: true,
                priority: 10,
            }],
            metadata: None,
        };

        let json = serde_json::to_string(&policy).expect("policy serializes");
        let decoded: Policy = serde_json::from_str(&json).expect("policy deserializes");

        assert_eq!(decoded.id, policy.id);
        assert_eq!(decoded.name, policy.name);
        assert_eq!(decoded.version, policy.version);
        assert_eq!(decoded.description, policy.description);
        assert!(decoded.metadata.is_none());
        assert_eq!(decoded.rules.len(), 1);

        let rule = &decoded.rules[0];
        assert_eq!(rule.name, "deny-old-versions");
        assert!(rule.enabled);
        assert_eq!(rule.priority, 10);
        assert!(matches!(
            &rule.action,
            PolicyAction::Deny { reason } if reason == "unsupported version"
        ));
    }
}