1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
7#[serde(rename_all = "snake_case")]
8pub enum RuleAction {
9 Allow,
11 Deny { reason: String },
13 RequireApproval { prompt: String },
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RuleCondition {
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub action_type: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
25 pub command_pattern: Option<String>,
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub description_pattern: Option<String>,
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub workflow_id: Option<String>,
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub time_range: Option<String>,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub cost_above: Option<f64>,
38}
39
40impl RuleCondition {
41 pub fn matches(&self, ctx: &RuleContext) -> bool {
43 if let Some(ref at) = self.action_type {
44 if !pattern_matches(&ctx.action_type, at) {
45 return false;
46 }
47 }
48 if let Some(ref cp) = self.command_pattern {
49 if !pattern_matches(&ctx.command, cp) {
50 return false;
51 }
52 }
53 if let Some(ref dp) = self.description_pattern {
54 if !pattern_matches(&ctx.description, dp) {
55 return false;
56 }
57 }
58 if let Some(ref wid) = self.workflow_id {
59 if ctx.workflow_id.as_deref() != Some(wid.as_str()) {
60 return false;
61 }
62 }
63 if let Some(threshold) = self.cost_above {
64 if ctx.estimated_cost.unwrap_or(0.0) <= threshold {
65 return false;
66 }
67 }
68 if let Some(ref range) = self.time_range {
69 if !check_time_range(range) {
70 return false;
71 }
72 }
73 true
74 }
75}
76
/// The facts about a pending action that a [`RuleCondition`] is checked against.
#[derive(Clone, Debug)]
pub struct RuleContext {
    /// Kind of action (e.g. `"execute"`, `"read"`), matched by `RuleCondition::action_type`.
    pub action_type: String,
    /// Command text, matched by `RuleCondition::command_pattern`.
    pub command: String,
    /// Free-text description, matched by `RuleCondition::description_pattern`.
    pub description: String,
    /// Owning workflow, if any; compared exactly with `RuleCondition::workflow_id`.
    pub workflow_id: Option<String>,
    /// Estimated cost, if known; `RuleCondition::cost_above` treats `None` as `0.0`.
    pub estimated_cost: Option<f64>,
}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct PolicyRule {
90 pub name: String,
92 #[serde(default)]
94 pub description: String,
95 #[serde(default = "default_priority")]
97 pub priority: i32,
98 pub condition: RuleCondition,
100 pub action: RuleAction,
102 #[serde(default = "default_true")]
104 pub enabled: bool,
105}
106
/// Serde default for `PolicyRule::priority` when the field is omitted.
const fn default_priority() -> i32 {
    100
}

/// Serde default for boolean fields that should be on unless stated otherwise.
const fn default_true() -> bool {
    true
}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct PolicyRuleSet {
117 #[serde(default)]
118 pub rules: Vec<PolicyRule>,
119}
120
121impl PolicyRuleSet {
122 pub fn from_yaml(content: &str) -> Result<Self, serde_yaml::Error> {
124 serde_yaml::from_str(content)
125 }
126
127 pub fn evaluate(&self, ctx: &RuleContext) -> Option<RuleAction> {
130 let mut sorted_rules: Vec<&PolicyRule> =
131 self.rules.iter().filter(|r| r.enabled).collect();
132 sorted_rules.sort_by_key(|r| r.priority);
133
134 for rule in sorted_rules {
135 if rule.condition.matches(ctx) {
136 return Some(rule.action.clone());
137 }
138 }
139 None
140 }
141
142 pub fn add_rule(&mut self, rule: PolicyRule) {
144 self.rules.push(rule);
145 }
146
147 pub fn remove_rule(&mut self, name: &str) -> bool {
149 let before = self.rules.len();
150 self.rules.retain(|r| r.name != name);
151 self.rules.len() < before
152 }
153}
154
/// Case-insensitive match of `text` against `pattern`.
///
/// Without `*`, `pattern` only needs to occur somewhere in `text` (a substring
/// match). With `*`, the pattern is split on `*` and the non-empty pieces must
/// appear in `text` in the same order, without overlapping. The match is not
/// anchored to either end of `text`, so `"deploy *"` also matches
/// `"redeploy staging"`.
pub fn pattern_matches(text: &str, pattern: &str) -> bool {
    let haystack = text.to_lowercase();
    let needle = pattern.to_lowercase();

    if !needle.contains('*') {
        return haystack.contains(needle.as_str());
    }

    // Greedily locate each piece after the end of the previous one; any miss
    // aborts the fold with `None`.
    needle
        .split('*')
        .filter(|segment| !segment.is_empty())
        .try_fold(0usize, |cursor, segment| {
            haystack[cursor..]
                .find(segment)
                .map(|offset| cursor + offset + segment.len())
        })
        .is_some()
}
177
178fn check_time_range(range: &str) -> bool {
180 let parts: Vec<&str> = range.split('-').collect();
181 if parts.len() != 2 {
182 return false; }
184
185 let parse_time = |s: &str| -> Option<chrono::NaiveTime> {
186 chrono::NaiveTime::parse_from_str(s.trim(), "%H:%M").ok()
187 };
188
189 let (start, end) = match (parse_time(parts[0]), parse_time(parts[1])) {
190 (Some(s), Some(e)) => (s, e),
191 _ => return false, };
193
194 let now = chrono::Local::now().time();
195
196 if start <= end {
197 now >= start && now <= end
199 } else {
200 now >= start || now <= end
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with no workflow ID and no cost estimate.
    fn make_ctx(action_type: &str, command: &str, description: &str) -> RuleContext {
        RuleContext {
            action_type: action_type.into(),
            command: command.into(),
            description: description.into(),
            workflow_id: None,
            estimated_cost: None,
        }
    }

    /// Builds a condition with every criterion unset, which matches any context.
    fn any_condition() -> RuleCondition {
        RuleCondition {
            action_type: None,
            command_pattern: None,
            description_pattern: None,
            workflow_id: None,
            time_range: None,
            cost_above: None,
        }
    }

    /// Builds a condition that constrains only the command pattern.
    fn command_condition(pattern: &str) -> RuleCondition {
        RuleCondition {
            command_pattern: Some(pattern.into()),
            ..any_condition()
        }
    }

    /// Builds an enabled rule with an empty description.
    fn make_rule(
        name: &str,
        priority: i32,
        condition: RuleCondition,
        action: RuleAction,
    ) -> PolicyRule {
        PolicyRule {
            name: name.into(),
            description: String::new(),
            priority,
            condition,
            action,
            enabled: true,
        }
    }

    #[test]
    fn test_pattern_matches_exact() {
        assert!(pattern_matches("git push", "git push"));
        assert!(!pattern_matches("git pull", "git push"));
    }

    #[test]
    fn test_pattern_matches_wildcard() {
        assert!(pattern_matches("deploy production", "deploy *"));
        assert!(pattern_matches("deploy staging", "deploy *"));
    }

    #[test]
    fn test_pattern_matches_multiple_wildcards() {
        assert!(pattern_matches("git push --force origin", "git*--force*"));
        // Pieces must appear in pattern order.
        assert!(!pattern_matches("b then a", "a*b"));
        assert!(pattern_matches("anything", "*"));
    }

    #[test]
    fn test_pattern_matches_case_insensitive() {
        assert!(pattern_matches("DROP DATABASE", "drop database"));
    }

    #[test]
    fn test_rule_condition_matches() {
        let cond = RuleCondition {
            action_type: Some("execute".into()),
            command_pattern: Some("git push*".into()),
            ..any_condition()
        };

        let ctx = make_ctx("execute", "git push origin main", "Push to remote");
        assert!(cond.matches(&ctx));

        // Same command but a different action type: the action_type criterion
        // must reject it. (The previous assertion here was vacuously true.)
        let ctx2 = make_ctx("read", "git push origin main", "Push to remote");
        assert!(!cond.matches(&ctx2));
    }

    #[test]
    fn test_empty_condition_matches_everything() {
        assert!(any_condition().matches(&make_ctx("", "", "")));
        assert!(any_condition().matches(&make_ctx("execute", "rm -rf /", "anything")));
    }

    #[test]
    fn test_rule_condition_workflow_id() {
        let cond = RuleCondition {
            workflow_id: Some("wf-1".into()),
            ..any_condition()
        };

        let mut ctx = make_ctx("execute", "run", "Run step");
        assert!(!cond.matches(&ctx), "missing workflow id must not match");

        ctx.workflow_id = Some("wf-2".into());
        assert!(!cond.matches(&ctx));

        ctx.workflow_id = Some("wf-1".into());
        assert!(cond.matches(&ctx));
    }

    #[test]
    fn test_rule_condition_cost_threshold() {
        let cond = RuleCondition {
            cost_above: Some(1.0),
            ..any_condition()
        };

        let mut ctx = make_ctx("api_call", "model invoke", "Call GPT-4");
        ctx.estimated_cost = Some(0.5);
        assert!(!cond.matches(&ctx));

        ctx.estimated_cost = Some(1.5);
        assert!(cond.matches(&ctx));

        // The threshold is exclusive.
        ctx.estimated_cost = Some(1.0);
        assert!(!cond.matches(&ctx));

        // A missing estimate is treated as 0.0.
        ctx.estimated_cost = None;
        assert!(!cond.matches(&ctx));
    }

    #[test]
    fn test_check_time_range_rejects_malformed() {
        assert!(!check_time_range("bogus"));
        assert!(!check_time_range("09:00"));
        assert!(!check_time_range("09:00-10:00-11:00"));
        assert!(!check_time_range("25:00-26:00"));
    }

    #[test]
    fn test_policy_rule_set_evaluate() {
        let rule_set = PolicyRuleSet {
            rules: vec![
                PolicyRule {
                    description: "Block production deploys".into(),
                    ..make_rule(
                        "block-production",
                        1,
                        command_condition("deploy production*"),
                        RuleAction::Deny {
                            reason: "Production deploys require manual approval flow".into(),
                        },
                    )
                },
                PolicyRule {
                    description: "Allow staging deploys".into(),
                    ..make_rule(
                        "allow-staging",
                        10,
                        command_condition("deploy staging*"),
                        RuleAction::Allow,
                    )
                },
            ],
        };

        let ctx = make_ctx("execute", "deploy production v2", "Deploy to prod");
        let result = rule_set.evaluate(&ctx);
        assert!(matches!(result, Some(RuleAction::Deny { .. })));

        let ctx2 = make_ctx("execute", "deploy staging v2", "Deploy to staging");
        let result2 = rule_set.evaluate(&ctx2);
        assert!(matches!(result2, Some(RuleAction::Allow)));

        let ctx3 = make_ctx("read", "git log", "View history");
        let result3 = rule_set.evaluate(&ctx3);
        assert!(result3.is_none());
    }

    #[test]
    fn test_add_and_remove_rule() {
        let mut rule_set = PolicyRuleSet { rules: vec![] };

        rule_set.add_rule(make_rule(
            "test-rule",
            50,
            command_condition("*"),
            RuleAction::Allow,
        ));

        assert_eq!(rule_set.rules.len(), 1);
        assert!(rule_set.remove_rule("test-rule"));
        assert!(rule_set.rules.is_empty());
        assert!(!rule_set.remove_rule("nonexistent"));
    }

    #[test]
    fn test_disabled_rules_skipped() {
        let rule_set = PolicyRuleSet {
            rules: vec![PolicyRule {
                enabled: false,
                ..make_rule(
                    "disabled-rule",
                    1,
                    command_condition("*"),
                    RuleAction::Deny {
                        reason: "should not fire".into(),
                    },
                )
            }],
        };

        let ctx = make_ctx("execute", "anything", "anything");
        assert!(rule_set.evaluate(&ctx).is_none());
    }

    #[test]
    fn test_rule_set_from_yaml() {
        let rule_set = PolicyRuleSet {
            rules: vec![
                PolicyRule {
                    description: "Block dangerous rm commands".into(),
                    ..make_rule(
                        "block-rm",
                        1,
                        command_condition("rm -rf *"),
                        RuleAction::Deny {
                            reason: "rm -rf is forbidden".into(),
                        },
                    )
                },
                make_rule(
                    "allow-read",
                    100,
                    RuleCondition {
                        action_type: Some("read".into()),
                        ..any_condition()
                    },
                    RuleAction::Allow,
                ),
            ],
        };

        let yaml = serde_yaml::to_string(&rule_set).unwrap();
        let parsed = PolicyRuleSet::from_yaml(&yaml).unwrap();
        assert_eq!(parsed.rules.len(), 2);
        assert_eq!(parsed.rules[0].name, "block-rm");
        assert!(matches!(parsed.rules[1].action, RuleAction::Allow));
    }

    #[test]
    fn test_from_yaml_applies_defaults() {
        let yaml = r#"
rules:
  - name: allow-reads
    condition:
      action_type: read
    action: allow
"#;
        let parsed = PolicyRuleSet::from_yaml(yaml).unwrap();
        assert_eq!(parsed.rules.len(), 1);

        let rule = &parsed.rules[0];
        assert_eq!(rule.priority, 100);
        assert!(rule.enabled);
        assert!(rule.description.is_empty());
        assert_eq!(
            parsed.evaluate(&make_ctx("read", "cat file", "Read a file")),
            Some(RuleAction::Allow)
        );

        // A document without `rules` yields an empty set.
        let empty = PolicyRuleSet::from_yaml("{}").unwrap();
        assert!(empty.rules.is_empty());
    }

    #[test]
    fn test_rule_action_serialization() {
        let allow = RuleAction::Allow;
        let json = serde_json::to_string(&allow).unwrap();
        assert_eq!(json, "\"allow\"");

        let deny = RuleAction::Deny {
            reason: "forbidden".into(),
        };
        let json = serde_json::to_string(&deny).unwrap();
        assert_eq!(json, r#"{"deny":{"reason":"forbidden"}}"#);
        let back: RuleAction = serde_json::from_str(&json).unwrap();
        assert_eq!(back, deny);
    }

    #[test]
    fn test_priority_ordering() {
        let rule_set = PolicyRuleSet {
            rules: vec![
                make_rule(
                    "low-priority",
                    100,
                    command_condition("test*"),
                    RuleAction::Allow,
                ),
                make_rule(
                    "high-priority",
                    1,
                    command_condition("test*"),
                    RuleAction::Deny {
                        reason: "high priority wins".into(),
                    },
                ),
            ],
        };

        let ctx = make_ctx("execute", "test something", "test");
        let result = rule_set.evaluate(&ctx);
        assert!(matches!(result, Some(RuleAction::Deny { .. })));
    }

    #[test]
    fn test_equal_priority_keeps_declaration_order() {
        let rule_set = PolicyRuleSet {
            rules: vec![
                make_rule(
                    "first",
                    5,
                    command_condition("*"),
                    RuleAction::Deny {
                        reason: "first".into(),
                    },
                ),
                make_rule("second", 5, command_condition("*"), RuleAction::Allow),
            ],
        };

        let result = rule_set.evaluate(&make_ctx("execute", "anything", "anything"));
        assert_eq!(
            result,
            Some(RuleAction::Deny {
                reason: "first".into()
            })
        );
    }
}