// mpl_core/policy.rs

1//! Policy Engine Lite
2//!
3//! Simple rule-based policy engine for enforcing SType usage rules,
4//! access control, and semantic constraints.
5
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::stype::SType;
10
11/// Policy Engine for enforcing rules on SType operations
12#[derive(Debug, Clone, Default)]
13pub struct PolicyEngine {
14    /// Registered policies
15    policies: Vec<Policy>,
16    /// Default QoM profile
17    default_profile: Option<String>,
18}
19
20impl PolicyEngine {
21    /// Create a new policy engine
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Add a policy
27    pub fn add_policy(&mut self, policy: Policy) {
28        self.policies.push(policy);
29    }
30
31    /// Set default QoM profile
32    pub fn set_default_profile(&mut self, profile: impl Into<String>) {
33        self.default_profile = Some(profile.into());
34    }
35
36    /// Evaluate policies for an SType operation
37    pub fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
38        let mut decision = PolicyDecision::allow();
39
40        for policy in &self.policies {
41            if policy.matches(context) {
42                let result = policy.evaluate(context);
43                decision = decision.merge(result);
44
45                // Short-circuit on deny
46                if decision.action == PolicyAction::Deny {
47                    return decision;
48                }
49            }
50        }
51
52        decision
53    }
54
55    /// Get required QoM profile for an SType
56    pub fn required_profile(&self, stype: &SType) -> Option<&str> {
57        for policy in &self.policies {
58            if let Some(ref qom_rule) = policy.qom_override {
59                if policy.matches_stype(stype) {
60                    return Some(&qom_rule.profile);
61                }
62            }
63        }
64        self.default_profile.as_deref()
65    }
66
67    /// Load policies from configuration
68    pub fn from_config(config: PolicyConfig) -> Self {
69        let mut engine = Self::new();
70        engine.default_profile = config.default_profile;
71        engine.policies = config.policies;
72        engine
73    }
74}
75
76/// Policy definition
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct Policy {
79    /// Policy name
80    pub name: String,
81
82    /// Description
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub description: Option<String>,
85
86    /// SType patterns this policy applies to
87    #[serde(default)]
88    pub stype_patterns: Vec<StypePattern>,
89
90    /// Operation types this policy applies to
91    #[serde(default)]
92    pub operations: HashSet<Operation>,
93
94    /// Access control rules
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub access_control: Option<AccessControlRule>,
97
98    /// QoM profile override
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub qom_override: Option<QomOverride>,
101
102    /// Rate limiting
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub rate_limit: Option<RateLimit>,
105
106    /// Custom constraints
107    #[serde(default)]
108    pub constraints: Vec<Constraint>,
109
110    /// Priority (higher = evaluated first)
111    #[serde(default)]
112    pub priority: i32,
113
114    /// Whether this policy is enabled
115    #[serde(default = "default_true")]
116    pub enabled: bool,
117}
118
/// Serde default for `Policy::enabled`: a policy that omits `enabled` is active.
fn default_true() -> bool {
    true
}
122
123impl Policy {
124    /// Create a new policy
125    pub fn new(name: impl Into<String>) -> Self {
126        Self {
127            name: name.into(),
128            description: None,
129            stype_patterns: Vec::new(),
130            operations: HashSet::new(),
131            access_control: None,
132            qom_override: None,
133            rate_limit: None,
134            constraints: Vec::new(),
135            priority: 0,
136            enabled: true,
137        }
138    }
139
140    /// Add an SType pattern
141    pub fn with_stype_pattern(mut self, pattern: StypePattern) -> Self {
142        self.stype_patterns.push(pattern);
143        self
144    }
145
146    /// Add operations
147    pub fn with_operations(mut self, ops: impl IntoIterator<Item = Operation>) -> Self {
148        self.operations.extend(ops);
149        self
150    }
151
152    /// Set access control
153    pub fn with_access_control(mut self, rule: AccessControlRule) -> Self {
154        self.access_control = Some(rule);
155        self
156    }
157
158    /// Set QoM override
159    pub fn with_qom_override(mut self, profile: impl Into<String>) -> Self {
160        self.qom_override = Some(QomOverride {
161            profile: profile.into(),
162            reason: None,
163        });
164        self
165    }
166
167    /// Check if policy matches context
168    pub fn matches(&self, context: &PolicyContext) -> bool {
169        if !self.enabled {
170            return false;
171        }
172
173        // Check operations
174        if !self.operations.is_empty() && !self.operations.contains(&context.operation) {
175            return false;
176        }
177
178        // Check SType patterns
179        if !self.stype_patterns.is_empty() {
180            let matches_stype = self.stype_patterns.iter().any(|p| p.matches(&context.stype));
181            if !matches_stype {
182                return false;
183            }
184        }
185
186        true
187    }
188
189    /// Check if policy matches an SType
190    pub fn matches_stype(&self, stype: &SType) -> bool {
191        if self.stype_patterns.is_empty() {
192            return true;
193        }
194        self.stype_patterns.iter().any(|p| p.matches(stype))
195    }
196
197    /// Evaluate policy against context
198    pub fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
199        let mut decision = PolicyDecision::allow();
200
201        // Check access control
202        if let Some(ref acl) = self.access_control {
203            if !acl.is_allowed(&context.principal, &context.operation) {
204                return PolicyDecision::deny(format!(
205                    "Access denied: {} not allowed for operation {:?}",
206                    context.principal.as_deref().unwrap_or("anonymous"),
207                    context.operation
208                ));
209            }
210        }
211
212        // Check constraints
213        for constraint in &self.constraints {
214            if !constraint.evaluate(context) {
215                decision.warnings.push(format!(
216                    "Constraint '{}' not satisfied",
217                    constraint.name
218                ));
219                if constraint.required {
220                    return PolicyDecision::deny(format!(
221                        "Required constraint '{}' not satisfied",
222                        constraint.name
223                    ));
224                }
225            }
226        }
227
228        // Apply QoM override
229        if let Some(ref qom) = self.qom_override {
230            decision.required_profile = Some(qom.profile.clone());
231        }
232
233        decision
234    }
235}
236
237/// SType pattern for matching
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct StypePattern {
240    /// Namespace pattern (supports wildcards)
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub namespace: Option<String>,
243
244    /// Domain pattern (supports wildcards)
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub domain: Option<String>,
247
248    /// Name pattern (supports wildcards)
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub name: Option<String>,
251
252    /// Version constraint
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub version: Option<VersionConstraint>,
255}
256
257impl StypePattern {
258    /// Create a pattern matching all STypes
259    pub fn all() -> Self {
260        Self {
261            namespace: None,
262            domain: None,
263            name: None,
264            version: None,
265        }
266    }
267
268    /// Create a pattern matching a namespace
269    pub fn namespace(ns: impl Into<String>) -> Self {
270        Self {
271            namespace: Some(ns.into()),
272            domain: None,
273            name: None,
274            version: None,
275        }
276    }
277
278    /// Create a pattern matching namespace and domain
279    pub fn namespace_domain(ns: impl Into<String>, domain: impl Into<String>) -> Self {
280        Self {
281            namespace: Some(ns.into()),
282            domain: Some(domain.into()),
283            name: None,
284            version: None,
285        }
286    }
287
288    /// Check if pattern matches an SType
289    pub fn matches(&self, stype: &SType) -> bool {
290        // Check namespace
291        if let Some(ref ns_pattern) = self.namespace {
292            if !glob_match(ns_pattern, &stype.namespace) {
293                return false;
294            }
295        }
296
297        // Check domain
298        if let Some(ref domain_pattern) = self.domain {
299            if !glob_match(domain_pattern, &stype.domain) {
300                return false;
301            }
302        }
303
304        // Check name
305        if let Some(ref name_pattern) = self.name {
306            if !glob_match(name_pattern, &stype.name) {
307                return false;
308            }
309        }
310
311        // Check version
312        if let Some(ref version_constraint) = self.version {
313            if !version_constraint.matches(stype.major_version) {
314                return false;
315            }
316        }
317
318        true
319    }
320}
321
322/// Version constraint for pattern matching
323#[derive(Debug, Clone, Serialize, Deserialize)]
324#[serde(tag = "op")]
325pub enum VersionConstraint {
326    #[serde(rename = "eq")]
327    Eq { version: u32 },
328    #[serde(rename = "gte")]
329    Gte { version: u32 },
330    #[serde(rename = "lte")]
331    Lte { version: u32 },
332    #[serde(rename = "range")]
333    Range { min: u32, max: u32 },
334}
335
336impl VersionConstraint {
337    pub fn matches(&self, version: u32) -> bool {
338        match self {
339            VersionConstraint::Eq { version: v } => version == *v,
340            VersionConstraint::Gte { version: v } => version >= *v,
341            VersionConstraint::Lte { version: v } => version <= *v,
342            VersionConstraint::Range { min, max } => version >= *min && version <= *max,
343        }
344    }
345}
346
347/// Operation types
348#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
349#[serde(rename_all = "snake_case")]
350pub enum Operation {
351    /// Read/fetch schema or data
352    Read,
353    /// Create new payload
354    Create,
355    /// Update existing payload
356    Update,
357    /// Delete payload
358    Delete,
359    /// Validate payload
360    Validate,
361    /// Execute tool call
362    Execute,
363    /// Subscribe to events
364    Subscribe,
365}
366
367/// Access control rule
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct AccessControlRule {
370    /// Allowed principals (user/service IDs)
371    #[serde(default)]
372    pub allow: HashSet<String>,
373
374    /// Denied principals
375    #[serde(default)]
376    pub deny: HashSet<String>,
377
378    /// Allowed operations per principal
379    #[serde(default)]
380    pub operation_map: HashMap<String, HashSet<Operation>>,
381
382    /// Default action when no match
383    #[serde(default)]
384    pub default: AccessDefault,
385}
386
387impl AccessControlRule {
388    /// Check if principal is allowed for operation
389    pub fn is_allowed(&self, principal: &Option<String>, operation: &Operation) -> bool {
390        let principal = principal.as_deref().unwrap_or("anonymous");
391
392        // Check explicit deny
393        if self.deny.contains(principal) || self.deny.contains("*") {
394            return false;
395        }
396
397        // Check explicit allow
398        if self.allow.contains(principal) || self.allow.contains("*") {
399            // Check operation restrictions
400            if let Some(ops) = self.operation_map.get(principal) {
401                return ops.contains(operation);
402            }
403            return true;
404        }
405
406        // Default action
407        matches!(self.default, AccessDefault::Allow)
408    }
409}
410
411/// Default access action
412#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
413#[serde(rename_all = "snake_case")]
414pub enum AccessDefault {
415    Allow,
416    #[default]
417    Deny,
418}
419
420/// QoM profile override
421#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct QomOverride {
423    /// Required profile name
424    pub profile: String,
425
426    /// Reason for override
427    #[serde(skip_serializing_if = "Option::is_none")]
428    pub reason: Option<String>,
429}
430
431/// Rate limiting configuration
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub struct RateLimit {
434    /// Requests per window
435    pub requests: u32,
436
437    /// Window duration in seconds
438    pub window_seconds: u32,
439
440    /// Per-principal or global
441    #[serde(default)]
442    pub per_principal: bool,
443}
444
445/// Custom constraint
446#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct Constraint {
448    /// Constraint name
449    pub name: String,
450
451    /// Constraint expression (simplified)
452    pub expression: ConstraintExpr,
453
454    /// Whether constraint is required (deny) or advisory (warn)
455    #[serde(default)]
456    pub required: bool,
457}
458
459impl Constraint {
460    /// Evaluate constraint against context
461    pub fn evaluate(&self, context: &PolicyContext) -> bool {
462        self.expression.evaluate(context)
463    }
464}
465
466/// Constraint expression
467#[derive(Debug, Clone, Serialize, Deserialize)]
468#[serde(tag = "type")]
469pub enum ConstraintExpr {
470    /// Check if metadata contains key
471    #[serde(rename = "has_metadata")]
472    HasMetadata { key: String },
473
474    /// Check metadata value
475    #[serde(rename = "metadata_equals")]
476    MetadataEquals { key: String, value: String },
477
478    /// Check payload size
479    #[serde(rename = "max_payload_size")]
480    MaxPayloadSize { bytes: usize },
481
482    /// Always pass
483    #[serde(rename = "always")]
484    Always,
485
486    /// Always fail
487    #[serde(rename = "never")]
488    Never,
489}
490
491impl ConstraintExpr {
492    pub fn evaluate(&self, context: &PolicyContext) -> bool {
493        match self {
494            ConstraintExpr::HasMetadata { key } => {
495                context.metadata.contains_key(key)
496            }
497            ConstraintExpr::MetadataEquals { key, value } => {
498                context.metadata.get(key).map(|v| v == value).unwrap_or(false)
499            }
500            ConstraintExpr::MaxPayloadSize { bytes } => {
501                context.payload_size.map(|s| s <= *bytes).unwrap_or(true)
502            }
503            ConstraintExpr::Always => true,
504            ConstraintExpr::Never => false,
505        }
506    }
507}
508
509/// Context for policy evaluation
510#[derive(Debug, Clone)]
511pub struct PolicyContext {
512    /// SType being operated on
513    pub stype: SType,
514
515    /// Operation type
516    pub operation: Operation,
517
518    /// Principal (user/service ID)
519    pub principal: Option<String>,
520
521    /// Additional metadata
522    pub metadata: HashMap<String, String>,
523
524    /// Payload size (if applicable)
525    pub payload_size: Option<usize>,
526}
527
528impl PolicyContext {
529    /// Create a new policy context
530    pub fn new(stype: SType, operation: Operation) -> Self {
531        Self {
532            stype,
533            operation,
534            principal: None,
535            metadata: HashMap::new(),
536            payload_size: None,
537        }
538    }
539
540    /// Set principal
541    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
542        self.principal = Some(principal.into());
543        self
544    }
545
546    /// Add metadata
547    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
548        self.metadata.insert(key.into(), value.into());
549        self
550    }
551
552    /// Set payload size
553    pub fn with_payload_size(mut self, size: usize) -> Self {
554        self.payload_size = Some(size);
555        self
556    }
557}
558
559/// Policy decision result
560#[derive(Debug, Clone, Serialize, Deserialize)]
561pub struct PolicyDecision {
562    /// Action to take
563    pub action: PolicyAction,
564
565    /// Reason for decision
566    #[serde(skip_serializing_if = "Option::is_none")]
567    pub reason: Option<String>,
568
569    /// Required QoM profile (if any)
570    #[serde(skip_serializing_if = "Option::is_none")]
571    pub required_profile: Option<String>,
572
573    /// Warnings (non-blocking)
574    #[serde(default, skip_serializing_if = "Vec::is_empty")]
575    pub warnings: Vec<String>,
576}
577
578impl PolicyDecision {
579    /// Create an allow decision
580    pub fn allow() -> Self {
581        Self {
582            action: PolicyAction::Allow,
583            reason: None,
584            required_profile: None,
585            warnings: Vec::new(),
586        }
587    }
588
589    /// Create a deny decision
590    pub fn deny(reason: impl Into<String>) -> Self {
591        Self {
592            action: PolicyAction::Deny,
593            reason: Some(reason.into()),
594            required_profile: None,
595            warnings: Vec::new(),
596        }
597    }
598
599    /// Merge with another decision (more restrictive wins)
600    pub fn merge(mut self, other: PolicyDecision) -> Self {
601        // Deny takes precedence
602        if other.action == PolicyAction::Deny {
603            return other;
604        }
605
606        // Merge warnings
607        self.warnings.extend(other.warnings);
608
609        // Use more specific profile
610        if other.required_profile.is_some() {
611            self.required_profile = other.required_profile;
612        }
613
614        self
615    }
616
617    /// Check if allowed
618    pub fn is_allowed(&self) -> bool {
619        self.action == PolicyAction::Allow
620    }
621}
622
623/// Policy action
624#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
625#[serde(rename_all = "snake_case")]
626pub enum PolicyAction {
627    Allow,
628    Deny,
629}
630
631/// Policy configuration for loading from file
632#[derive(Debug, Clone, Default, Serialize, Deserialize)]
633pub struct PolicyConfig {
634    /// Default QoM profile
635    #[serde(skip_serializing_if = "Option::is_none")]
636    pub default_profile: Option<String>,
637
638    /// Policies
639    #[serde(default)]
640    pub policies: Vec<Policy>,
641}
642
/// Match `text` against a glob `pattern`.
///
/// In the pattern, `*` matches any (possibly empty) run of characters, and every other
/// character must match literally. The literal text before the first `*` must prefix
/// `text`. The literal text after the last `*` must suffix it without overlapping that
/// prefix. The literal segments in between must occur, in order, in what remains.
fn glob_match(pattern: &str, text: &str) -> bool {
    let (head, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        // No wildcard: plain equality.
        None => return pattern == text,
    };
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));

    // Strip the anchored prefix, then the anchored suffix from what is left, so the two
    // can never claim the same characters.
    let mut remaining = match text.strip_prefix(head).and_then(|s| s.strip_suffix(tail)) {
        Some(window) => window,
        None => return false,
    };

    // Leftmost matching of each inner segment is sufficient for `*`-only globs.
    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    true
}
682
#[cfg(test)]
mod tests {
    use super::*;

    /// `eval.rag.RAGQuery` at major version 1, the SType used throughout these tests.
    fn test_stype() -> SType {
        SType {
            namespace: "eval".to_string(),
            domain: "rag".to_string(),
            name: "RAGQuery".to_string(),
            major_version: 1,
        }
    }

    /// Anonymous `Execute` request against [`test_stype`].
    fn execute_context() -> PolicyContext {
        PolicyContext::new(test_stype(), Operation::Execute)
    }

    #[test]
    fn test_pattern_all() {
        assert!(StypePattern::all().matches(&test_stype()));
    }

    #[test]
    fn test_pattern_namespace() {
        let stype = test_stype();
        assert!(StypePattern::namespace("eval").matches(&stype));
        assert!(!StypePattern::namespace("org").matches(&stype));
    }

    #[test]
    fn test_pattern_wildcard() {
        let pattern = StypePattern {
            namespace: Some("ev*".to_string()),
            ..StypePattern::all()
        };
        assert!(pattern.matches(&test_stype()));
    }

    #[test]
    fn test_policy_engine() {
        // The eval namespace requires the strict QoM profile.
        let mut engine = PolicyEngine::new();
        engine.add_policy(
            Policy::new("eval-strict")
                .with_stype_pattern(StypePattern::namespace("eval"))
                .with_qom_override("qom-strict-argcheck"),
        );

        let decision = engine.evaluate(&execute_context());

        assert!(decision.is_allowed());
        assert_eq!(decision.required_profile.as_deref(), Some("qom-strict-argcheck"));
    }

    #[test]
    fn test_access_control_deny() {
        let mut engine = PolicyEngine::new();
        engine.add_policy(
            Policy::new("restricted")
                .with_stype_pattern(StypePattern::namespace("eval"))
                .with_access_control(AccessControlRule {
                    allow: HashSet::from(["admin".to_string()]),
                    deny: HashSet::new(),
                    operation_map: HashMap::new(),
                    default: AccessDefault::Deny,
                }),
        );

        // No principal is checked as "anonymous", which is not allowed.
        assert!(!engine.evaluate(&execute_context()).is_allowed());
        // An explicitly allowed principal passes.
        assert!(engine
            .evaluate(&execute_context().with_principal("admin"))
            .is_allowed());
    }

    #[test]
    fn test_glob_matching() {
        let cases = [
            ("*", "anything", true),
            ("eval", "eval", true),
            ("eval", "org", false),
            ("ev*", "eval", true),
            ("*val", "eval", true),
            ("e*l", "eval", true),
            ("*a*", "eval", true),
        ];
        for (pattern, text, expected) in cases {
            assert_eq!(glob_match(pattern, text), expected, "{:?} vs {:?}", pattern, text);
        }
    }
}