1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::stype::SType;
10
/// Holds an ordered list of policies plus an optional fallback profile.
///
/// Policies are stored in the order they were added (or loaded from a
/// [`PolicyConfig`]); that order is the order in which they are evaluated.
#[derive(Debug, Clone, Default)]
pub struct PolicyEngine {
    /// Registered policies, in insertion order.
    policies: Vec<Policy>,
    /// Profile name returned by `required_profile` when no policy override applies.
    default_profile: Option<String>,
}
19
20impl PolicyEngine {
21 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn add_policy(&mut self, policy: Policy) {
28 self.policies.push(policy);
29 }
30
31 pub fn set_default_profile(&mut self, profile: impl Into<String>) {
33 self.default_profile = Some(profile.into());
34 }
35
36 pub fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
38 let mut decision = PolicyDecision::allow();
39
40 for policy in &self.policies {
41 if policy.matches(context) {
42 let result = policy.evaluate(context);
43 decision = decision.merge(result);
44
45 if decision.action == PolicyAction::Deny {
47 return decision;
48 }
49 }
50 }
51
52 decision
53 }
54
55 pub fn required_profile(&self, stype: &SType) -> Option<&str> {
57 for policy in &self.policies {
58 if let Some(ref qom_rule) = policy.qom_override {
59 if policy.matches_stype(stype) {
60 return Some(&qom_rule.profile);
61 }
62 }
63 }
64 self.default_profile.as_deref()
65 }
66
67 pub fn from_config(config: PolicyConfig) -> Self {
69 let mut engine = Self::new();
70 engine.default_profile = config.default_profile;
71 engine.policies = config.policies;
72 engine
73 }
74}
75
/// A named rule set applied to requests whose operation and stype match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Policy {
    /// Name of the policy.
    pub name: String,

    /// Optional free-form description; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// Stype patterns this policy applies to. An empty list matches every stype.
    #[serde(default)]
    pub stype_patterns: Vec<StypePattern>,

    /// Operations this policy applies to. An empty set matches every operation.
    #[serde(default)]
    pub operations: HashSet<Operation>,

    /// Optional access-control rule checked first during evaluation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_control: Option<AccessControlRule>,

    /// Optional profile override attached to an allow decision.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qom_override: Option<QomOverride>,

    /// Optional rate limit.
    /// NOTE(review): not read by any evaluation code visible in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limit: Option<RateLimit>,

    /// Constraints checked during evaluation; a failed `required` constraint denies.
    #[serde(default)]
    pub constraints: Vec<Constraint>,

    /// Priority value; defaults to 0 when absent from the input.
    /// NOTE(review): not consulted by any evaluation code visible in this file.
    #[serde(default)]
    pub priority: i32,

    /// Disabled policies never match. Defaults to `true` when absent from the input.
    #[serde(default = "default_true")]
    pub enabled: bool,
}
118
/// Serde default helper: yields `true` for boolean fields that should be on
/// unless the input explicitly turns them off.
fn default_true() -> bool { true }
122
123impl Policy {
124 pub fn new(name: impl Into<String>) -> Self {
126 Self {
127 name: name.into(),
128 description: None,
129 stype_patterns: Vec::new(),
130 operations: HashSet::new(),
131 access_control: None,
132 qom_override: None,
133 rate_limit: None,
134 constraints: Vec::new(),
135 priority: 0,
136 enabled: true,
137 }
138 }
139
140 pub fn with_stype_pattern(mut self, pattern: StypePattern) -> Self {
142 self.stype_patterns.push(pattern);
143 self
144 }
145
146 pub fn with_operations(mut self, ops: impl IntoIterator<Item = Operation>) -> Self {
148 self.operations.extend(ops);
149 self
150 }
151
152 pub fn with_access_control(mut self, rule: AccessControlRule) -> Self {
154 self.access_control = Some(rule);
155 self
156 }
157
158 pub fn with_qom_override(mut self, profile: impl Into<String>) -> Self {
160 self.qom_override = Some(QomOverride {
161 profile: profile.into(),
162 reason: None,
163 });
164 self
165 }
166
167 pub fn matches(&self, context: &PolicyContext) -> bool {
169 if !self.enabled {
170 return false;
171 }
172
173 if !self.operations.is_empty() && !self.operations.contains(&context.operation) {
175 return false;
176 }
177
178 if !self.stype_patterns.is_empty() {
180 let matches_stype = self.stype_patterns.iter().any(|p| p.matches(&context.stype));
181 if !matches_stype {
182 return false;
183 }
184 }
185
186 true
187 }
188
189 pub fn matches_stype(&self, stype: &SType) -> bool {
191 if self.stype_patterns.is_empty() {
192 return true;
193 }
194 self.stype_patterns.iter().any(|p| p.matches(stype))
195 }
196
197 pub fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
199 let mut decision = PolicyDecision::allow();
200
201 if let Some(ref acl) = self.access_control {
203 if !acl.is_allowed(&context.principal, &context.operation) {
204 return PolicyDecision::deny(format!(
205 "Access denied: {} not allowed for operation {:?}",
206 context.principal.as_deref().unwrap_or("anonymous"),
207 context.operation
208 ));
209 }
210 }
211
212 for constraint in &self.constraints {
214 if !constraint.evaluate(context) {
215 decision.warnings.push(format!(
216 "Constraint '{}' not satisfied",
217 constraint.name
218 ));
219 if constraint.required {
220 return PolicyDecision::deny(format!(
221 "Required constraint '{}' not satisfied",
222 constraint.name
223 ));
224 }
225 }
226 }
227
228 if let Some(ref qom) = self.qom_override {
230 decision.required_profile = Some(qom.profile.clone());
231 }
232
233 decision
234 }
235}
236
/// A pattern over the components of an `SType`.
///
/// Every field is optional; a `None` field places no restriction on the
/// corresponding component. `None` fields are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StypePattern {
    /// Glob pattern (`*` wildcard) for the stype namespace.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,

    /// Glob pattern (`*` wildcard) for the stype domain.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,

    /// Glob pattern (`*` wildcard) for the stype name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Constraint on the stype's major version.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<VersionConstraint>,
}
256
257impl StypePattern {
258 pub fn all() -> Self {
260 Self {
261 namespace: None,
262 domain: None,
263 name: None,
264 version: None,
265 }
266 }
267
268 pub fn namespace(ns: impl Into<String>) -> Self {
270 Self {
271 namespace: Some(ns.into()),
272 domain: None,
273 name: None,
274 version: None,
275 }
276 }
277
278 pub fn namespace_domain(ns: impl Into<String>, domain: impl Into<String>) -> Self {
280 Self {
281 namespace: Some(ns.into()),
282 domain: Some(domain.into()),
283 name: None,
284 version: None,
285 }
286 }
287
288 pub fn matches(&self, stype: &SType) -> bool {
290 if let Some(ref ns_pattern) = self.namespace {
292 if !glob_match(ns_pattern, &stype.namespace) {
293 return false;
294 }
295 }
296
297 if let Some(ref domain_pattern) = self.domain {
299 if !glob_match(domain_pattern, &stype.domain) {
300 return false;
301 }
302 }
303
304 if let Some(ref name_pattern) = self.name {
306 if !glob_match(name_pattern, &stype.name) {
307 return false;
308 }
309 }
310
311 if let Some(ref version_constraint) = self.version {
313 if !version_constraint.matches(stype.major_version) {
314 return false;
315 }
316 }
317
318 true
319 }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
324#[serde(tag = "op")]
325pub enum VersionConstraint {
326 #[serde(rename = "eq")]
327 Eq { version: u32 },
328 #[serde(rename = "gte")]
329 Gte { version: u32 },
330 #[serde(rename = "lte")]
331 Lte { version: u32 },
332 #[serde(rename = "range")]
333 Range { min: u32, max: u32 },
334}
335
336impl VersionConstraint {
337 pub fn matches(&self, version: u32) -> bool {
338 match self {
339 VersionConstraint::Eq { version: v } => version == *v,
340 VersionConstraint::Gte { version: v } => version >= *v,
341 VersionConstraint::Lte { version: v } => version <= *v,
342 VersionConstraint::Range { min, max } => version >= *min && version <= *max,
343 }
344 }
345}
346
/// The kind of operation a request performs.
///
/// Serialized using snake_case variant names (e.g. `read`, `subscribe`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Operation {
    Read,
    Create,
    Update,
    Delete,
    Validate,
    Execute,
    Subscribe,
}
366
/// Allow/deny lists for principals, with optional per-principal operation limits.
///
/// In both lists the entry `"*"` acts as a wildcard for every principal.
/// All fields fall back to their `Default` when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessControlRule {
    /// Principals that are allowed.
    #[serde(default)]
    pub allow: HashSet<String>,

    /// Principals that are denied; checked before `allow`.
    #[serde(default)]
    pub deny: HashSet<String>,

    /// For an allowed principal with an entry here, only the listed operations are permitted.
    #[serde(default)]
    pub operation_map: HashMap<String, HashSet<Operation>>,

    /// Outcome for principals found in neither list.
    #[serde(default)]
    pub default: AccessDefault,
}
386
387impl AccessControlRule {
388 pub fn is_allowed(&self, principal: &Option<String>, operation: &Operation) -> bool {
390 let principal = principal.as_deref().unwrap_or("anonymous");
391
392 if self.deny.contains(principal) || self.deny.contains("*") {
394 return false;
395 }
396
397 if self.allow.contains(principal) || self.allow.contains("*") {
399 if let Some(ops) = self.operation_map.get(principal) {
401 return ops.contains(operation);
402 }
403 return true;
404 }
405
406 matches!(self.default, AccessDefault::Allow)
408 }
409}
410
/// Fallback outcome of an access-control rule for principals that appear in
/// neither the allow nor the deny list. The `Default` value is `Deny`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AccessDefault {
    /// Unlisted principals are accepted.
    Allow,
    /// Unlisted principals are rejected.
    #[default]
    Deny,
}
419
/// Profile override carried by a policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QomOverride {
    /// Name of the profile to require.
    pub profile: String,

    /// Optional explanation for the override; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
430
/// Rate-limit settings attached to a policy.
///
/// NOTE(review): nothing in the visible part of this file enforces these
/// values; the field meanings below are inferred from the names — confirm
/// against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimit {
    /// Presumably the number of requests permitted per window.
    pub requests: u32,

    /// Presumably the window length in seconds.
    pub window_seconds: u32,

    /// Presumably whether the limit is tracked per principal; defaults to `false` when absent.
    #[serde(default)]
    pub per_principal: bool,
}
444
/// A named condition checked during policy evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Constraint {
    /// Name used in warning and deny messages.
    pub name: String,

    /// The condition to evaluate against the request context.
    pub expression: ConstraintExpr,

    /// When `true`, failing this constraint denies the request; otherwise a
    /// failure only adds a warning. Defaults to `false` when absent.
    #[serde(default)]
    pub required: bool,
}
458
459impl Constraint {
460 pub fn evaluate(&self, context: &PolicyContext) -> bool {
462 self.expression.evaluate(context)
463 }
464}
465
466#[derive(Debug, Clone, Serialize, Deserialize)]
468#[serde(tag = "type")]
469pub enum ConstraintExpr {
470 #[serde(rename = "has_metadata")]
472 HasMetadata { key: String },
473
474 #[serde(rename = "metadata_equals")]
476 MetadataEquals { key: String, value: String },
477
478 #[serde(rename = "max_payload_size")]
480 MaxPayloadSize { bytes: usize },
481
482 #[serde(rename = "always")]
484 Always,
485
486 #[serde(rename = "never")]
488 Never,
489}
490
491impl ConstraintExpr {
492 pub fn evaluate(&self, context: &PolicyContext) -> bool {
493 match self {
494 ConstraintExpr::HasMetadata { key } => {
495 context.metadata.contains_key(key)
496 }
497 ConstraintExpr::MetadataEquals { key, value } => {
498 context.metadata.get(key).map(|v| v == value).unwrap_or(false)
499 }
500 ConstraintExpr::MaxPayloadSize { bytes } => {
501 context.payload_size.map(|s| s <= *bytes).unwrap_or(true)
502 }
503 ConstraintExpr::Always => true,
504 ConstraintExpr::Never => false,
505 }
506 }
507}
508
/// Everything the engine knows about a single request being evaluated.
#[derive(Debug, Clone)]
pub struct PolicyContext {
    /// The stype the request targets.
    pub stype: SType,

    /// The operation being performed.
    pub operation: Operation,

    /// The requesting principal; `None` is treated as `"anonymous"` by access control.
    pub principal: Option<String>,

    /// Key/value metadata inspected by metadata constraints.
    pub metadata: HashMap<String, String>,

    /// Payload size, if known; compared against `MaxPayloadSize { bytes }`.
    pub payload_size: Option<usize>,
}
527
528impl PolicyContext {
529 pub fn new(stype: SType, operation: Operation) -> Self {
531 Self {
532 stype,
533 operation,
534 principal: None,
535 metadata: HashMap::new(),
536 payload_size: None,
537 }
538 }
539
540 pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
542 self.principal = Some(principal.into());
543 self
544 }
545
546 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
548 self.metadata.insert(key.into(), value.into());
549 self
550 }
551
552 pub fn with_payload_size(mut self, size: usize) -> Self {
554 self.payload_size = Some(size);
555 self
556 }
557}
558
/// Outcome of evaluating one or more policies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyDecision {
    /// Whether the request is allowed or denied.
    pub action: PolicyAction,

    /// Explanation for the decision; set by `deny`, `None` for a plain allow.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,

    /// Profile required by a matching policy override, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required_profile: Option<String>,

    /// Messages for unsatisfied non-required constraints; omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
}
577
578impl PolicyDecision {
579 pub fn allow() -> Self {
581 Self {
582 action: PolicyAction::Allow,
583 reason: None,
584 required_profile: None,
585 warnings: Vec::new(),
586 }
587 }
588
589 pub fn deny(reason: impl Into<String>) -> Self {
591 Self {
592 action: PolicyAction::Deny,
593 reason: Some(reason.into()),
594 required_profile: None,
595 warnings: Vec::new(),
596 }
597 }
598
599 pub fn merge(mut self, other: PolicyDecision) -> Self {
601 if other.action == PolicyAction::Deny {
603 return other;
604 }
605
606 self.warnings.extend(other.warnings);
608
609 if other.required_profile.is_some() {
611 self.required_profile = other.required_profile;
612 }
613
614 self
615 }
616
617 pub fn is_allowed(&self) -> bool {
619 self.action == PolicyAction::Allow
620 }
621}
622
/// Final verdict of a policy decision; serialized as `allow` / `deny`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PolicyAction {
    /// The request may proceed.
    Allow,
    /// The request is rejected.
    Deny,
}
630
/// Serializable configuration from which a `PolicyEngine` can be built.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PolicyConfig {
    /// Fallback profile for the engine; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_profile: Option<String>,

    /// Policies in evaluation order; defaults to an empty list when absent.
    #[serde(default)]
    pub policies: Vec<Policy>,
}
642
/// Matches `text` against `pattern`, where `*` stands for any (possibly
/// empty) run of characters and every other character matches literally.
///
/// A pattern without `*` must equal `text` exactly. Otherwise the text must
/// start with the part before the first `*`, contain the middle parts in
/// order without overlap, and end with the part after the last `*`.
fn glob_match(pattern: &str, text: &str) -> bool {
    // No wildcard at all: plain equality.
    let (head, wildcard_tail) = match pattern.split_once('*') {
        Some(halves) => halves,
        None => return pattern == text,
    };

    // The literal before the first `*` is anchored at the start of the text.
    let mut remaining = match text.strip_prefix(head) {
        Some(after_head) => after_head,
        None => return false,
    };

    // Consume each following literal at its leftmost position after the
    // previous one, remembering the last literal for the end anchor.
    let mut final_segment = "";
    for segment in wildcard_tail.split('*') {
        final_segment = segment;
        if segment.is_empty() {
            continue;
        }
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }

    // A pattern that does not end in `*` is anchored at the end of the text.
    final_segment.is_empty() || text.ends_with(final_segment)
}
682
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture stype `eval/rag/RAGQuery` at major version 1.
    fn test_stype() -> SType {
        SType {
            namespace: "eval".to_string(),
            domain: "rag".to_string(),
            name: "RAGQuery".to_string(),
            major_version: 1,
        }
    }

    #[test]
    fn test_pattern_all() {
        let pattern = StypePattern::all();
        assert!(pattern.matches(&test_stype()));
    }

    #[test]
    fn test_pattern_namespace() {
        let pattern = StypePattern::namespace("eval");
        assert!(pattern.matches(&test_stype()));

        let pattern = StypePattern::namespace("org");
        assert!(!pattern.matches(&test_stype()));
    }

    #[test]
    fn test_pattern_namespace_domain() {
        let pattern = StypePattern::namespace_domain("eval", "rag");
        assert!(pattern.matches(&test_stype()));

        let pattern = StypePattern::namespace_domain("eval", "search");
        assert!(!pattern.matches(&test_stype()));
    }

    #[test]
    fn test_pattern_wildcard() {
        let pattern = StypePattern {
            namespace: Some("ev*".to_string()),
            domain: None,
            name: None,
            version: None,
        };
        assert!(pattern.matches(&test_stype()));
    }

    #[test]
    fn test_pattern_version() {
        let mut pattern = StypePattern::all();

        pattern.version = Some(VersionConstraint::Gte { version: 1 });
        assert!(pattern.matches(&test_stype()));

        pattern.version = Some(VersionConstraint::Gte { version: 2 });
        assert!(!pattern.matches(&test_stype()));
    }

    #[test]
    fn test_version_constraint() {
        assert!(VersionConstraint::Eq { version: 2 }.matches(2));
        assert!(!VersionConstraint::Eq { version: 2 }.matches(3));
        assert!(VersionConstraint::Lte { version: 2 }.matches(2));
        assert!(!VersionConstraint::Lte { version: 2 }.matches(3));

        // Range bounds are inclusive at both ends.
        let range = VersionConstraint::Range { min: 1, max: 3 };
        assert!(range.matches(1));
        assert!(range.matches(3));
        assert!(!range.matches(0));
        assert!(!range.matches(4));
    }

    #[test]
    fn test_policy_engine() {
        let mut engine = PolicyEngine::new();

        let policy = Policy::new("eval-strict")
            .with_stype_pattern(StypePattern::namespace("eval"))
            .with_qom_override("qom-strict-argcheck");

        engine.add_policy(policy);

        let context = PolicyContext::new(test_stype(), Operation::Execute);
        let decision = engine.evaluate(&context);

        assert!(decision.is_allowed());
        assert_eq!(decision.required_profile, Some("qom-strict-argcheck".to_string()));
    }

    #[test]
    fn test_disabled_policy_is_not_evaluated() {
        let mut engine = PolicyEngine::new();

        let mut policy = Policy::new("switched-off").with_qom_override("qom-strict-argcheck");
        policy.enabled = false;
        engine.add_policy(policy);

        let context = PolicyContext::new(test_stype(), Operation::Read);
        let decision = engine.evaluate(&context);

        assert!(decision.is_allowed());
        assert_eq!(decision.required_profile, None);
    }

    #[test]
    fn test_operation_filter() {
        let policy = Policy::new("read-only").with_operations([Operation::Read]);

        assert!(policy.matches(&PolicyContext::new(test_stype(), Operation::Read)));
        assert!(!policy.matches(&PolicyContext::new(test_stype(), Operation::Delete)));
    }

    #[test]
    fn test_access_control_deny() {
        let mut engine = PolicyEngine::new();

        let policy = Policy::new("restricted")
            .with_stype_pattern(StypePattern::namespace("eval"))
            .with_access_control(AccessControlRule {
                allow: HashSet::from(["admin".to_string()]),
                deny: HashSet::new(),
                operation_map: HashMap::new(),
                default: AccessDefault::Deny,
            });

        engine.add_policy(policy);

        // No principal: treated as "anonymous", which is not on the allow list.
        let context = PolicyContext::new(test_stype(), Operation::Execute);
        let decision = engine.evaluate(&context);
        assert!(!decision.is_allowed());

        let context = PolicyContext::new(test_stype(), Operation::Execute)
            .with_principal("admin");
        let decision = engine.evaluate(&context);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_constraints() {
        let context = PolicyContext::new(test_stype(), Operation::Create);

        // An optional constraint that fails only produces a warning.
        let mut lenient = Policy::new("lenient");
        lenient.constraints.push(Constraint {
            name: "soft".to_string(),
            expression: ConstraintExpr::Never,
            required: false,
        });
        let decision = lenient.evaluate(&context);
        assert!(decision.is_allowed());
        assert_eq!(decision.warnings.len(), 1);

        // A required constraint that fails denies the request.
        let mut strict = Policy::new("strict");
        strict.constraints.push(Constraint {
            name: "hard".to_string(),
            expression: ConstraintExpr::Never,
            required: true,
        });
        let decision = strict.evaluate(&context);
        assert!(!decision.is_allowed());
        assert!(decision.reason.is_some());
    }

    #[test]
    fn test_glob_matching() {
        assert!(glob_match("*", "anything"));
        assert!(glob_match("eval", "eval"));
        assert!(!glob_match("eval", "org"));
        assert!(glob_match("ev*", "eval"));
        assert!(glob_match("*val", "eval"));
        assert!(glob_match("e*l", "eval"));
        assert!(glob_match("*a*", "eval"));
    }

    #[test]
    fn test_glob_matching_anchors_and_edges() {
        // Literal prefix must sit at the very start.
        assert!(!glob_match("ev*", "dev"));
        // Literal suffix must sit at the very end.
        assert!(!glob_match("*val", "evaluate"));
        assert!(!glob_match("e*l", "evals"));
        // `*` may match the empty string.
        assert!(glob_match("*", ""));
        assert!(glob_match("e*", "e"));
        // Prefix and suffix may not overlap.
        assert!(!glob_match("a*a", "a"));
        assert!(glob_match("a*a", "aa"));
        // Middle segments must appear in order.
        assert!(glob_match("a*b*c", "aXbYc"));
        assert!(!glob_match("a*b*c", "acb"));
        // Empty pattern only matches empty text.
        assert!(glob_match("", ""));
        assert!(!glob_match("", "x"));
    }
}