Skip to main content

enact_core/policy/
tenant_policy.rs

//! Tenant policy: per-tenant resource limits, feature flags, model/tool allow-lists, and isolation mode.
2
3use super::{PolicyContext, PolicyDecision, PolicyEvaluator};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// Tenant limits
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct TenantLimits {
10    /// Maximum concurrent executions
11    pub max_concurrent_executions: Option<usize>,
12    /// Maximum executions per day
13    pub max_executions_per_day: Option<usize>,
14    /// Maximum LLM tokens per day
15    pub max_tokens_per_day: Option<usize>,
16    /// Maximum storage in bytes
17    pub max_storage_bytes: Option<usize>,
18}
19
20impl Default for TenantLimits {
21    fn default() -> Self {
22        Self {
23            max_concurrent_executions: Some(10),
24            max_executions_per_day: Some(1000),
25            max_tokens_per_day: Some(1_000_000),
26            max_storage_bytes: Some(1024 * 1024 * 1024), // 1GB
27        }
28    }
29}
30
31impl TenantLimits {
32    /// Create unlimited tenant limits (for enterprise/unlimited tier)
33    pub fn unlimited() -> Self {
34        Self {
35            max_concurrent_executions: None,
36            max_executions_per_day: None,
37            max_tokens_per_day: None,
38            max_storage_bytes: None,
39        }
40    }
41
42    /// Create free tier limits
43    pub fn free_tier() -> Self {
44        Self {
45            max_concurrent_executions: Some(2),
46            max_executions_per_day: Some(100),
47            max_tokens_per_day: Some(50_000),
48            max_storage_bytes: Some(100 * 1024 * 1024), // 100MB
49        }
50    }
51
52    /// Create pro tier limits
53    pub fn pro_tier() -> Self {
54        Self {
55            max_concurrent_executions: Some(20),
56            max_executions_per_day: Some(10_000),
57            max_tokens_per_day: Some(10_000_000),
58            max_storage_bytes: Some(10 * 1024 * 1024 * 1024), // 10GB
59        }
60    }
61}
62
63/// Feature flags for tenant
64#[derive(Debug, Clone, Serialize, Deserialize, Default)]
65pub struct FeatureFlags {
66    /// Enabled features
67    pub enabled: HashSet<String>,
68    /// Disabled features (takes precedence)
69    pub disabled: HashSet<String>,
70}
71
72impl FeatureFlags {
73    /// Create new feature flags
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    /// Enable a feature
79    pub fn enable(mut self, feature: impl Into<String>) -> Self {
80        self.enabled.insert(feature.into());
81        self
82    }
83
84    /// Disable a feature
85    pub fn disable(mut self, feature: impl Into<String>) -> Self {
86        self.disabled.insert(feature.into());
87        self
88    }
89
90    /// Check if a feature is enabled
91    pub fn is_enabled(&self, feature: &str) -> bool {
92        !self.disabled.contains(feature) && self.enabled.contains(feature)
93    }
94
95    /// Common feature flags
96    pub fn with_defaults() -> Self {
97        Self::new()
98            .enable("basic_execution")
99            .enable("tool_invocation")
100            .enable("streaming")
101    }
102
103    /// All features enabled
104    pub fn all_enabled() -> Self {
105        Self::new()
106            .enable("basic_execution")
107            .enable("tool_invocation")
108            .enable("streaming")
109            .enable("parallel_execution")
110            .enable("nested_execution")
111            .enable("custom_tools")
112            .enable("mcp_integration")
113            .enable("advanced_memory")
114    }
115}
116
117/// Tenant policy
118#[derive(Debug, Clone)]
119pub struct TenantPolicy {
120    /// Tenant limits
121    pub limits: TenantLimits,
122    /// Feature flags
123    pub features: FeatureFlags,
124    /// Allowed models
125    pub allowed_models: HashSet<String>,
126    /// Allowed tools
127    pub allowed_tools: Option<HashSet<String>>,
128    /// Isolation mode
129    pub isolation: TenantIsolation,
130}
131
132/// Tenant isolation mode
133#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
134pub enum TenantIsolation {
135    /// Shared resources (default)
136    #[default]
137    Shared,
138    /// Dedicated resources
139    Dedicated,
140    /// Strict isolation (no shared state)
141    Strict,
142}
143
144impl Default for TenantPolicy {
145    fn default() -> Self {
146        Self {
147            limits: TenantLimits::default(),
148            features: FeatureFlags::with_defaults(),
149            allowed_models: HashSet::new(),
150            allowed_tools: None, // None means all tools allowed
151            isolation: TenantIsolation::default(),
152        }
153    }
154}
155
156impl TenantPolicy {
157    /// Create a new tenant policy
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Set limits
163    pub fn with_limits(mut self, limits: TenantLimits) -> Self {
164        self.limits = limits;
165        self
166    }
167
168    /// Set features
169    pub fn with_features(mut self, features: FeatureFlags) -> Self {
170        self.features = features;
171        self
172    }
173
174    /// Allow a specific model
175    pub fn allow_model(mut self, model: impl Into<String>) -> Self {
176        self.allowed_models.insert(model.into());
177        self
178    }
179
180    /// Allow a specific tool
181    pub fn allow_tool(mut self, tool: impl Into<String>) -> Self {
182        self.allowed_tools
183            .get_or_insert_with(HashSet::new)
184            .insert(tool.into());
185        self
186    }
187
188    /// Set isolation mode
189    pub fn with_isolation(mut self, isolation: TenantIsolation) -> Self {
190        self.isolation = isolation;
191        self
192    }
193
194    /// Check if a model is allowed
195    pub fn is_model_allowed(&self, model: &str) -> bool {
196        self.allowed_models.is_empty() || self.allowed_models.contains(model)
197    }
198
199    /// Check if a tool is allowed
200    pub fn is_tool_allowed(&self, tool: &str) -> bool {
201        self.allowed_tools
202            .as_ref()
203            .map(|tools| tools.contains(tool))
204            .unwrap_or(true)
205    }
206}
207
208impl PolicyEvaluator for TenantPolicy {
209    fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
210        match &context.action {
211            super::PolicyAction::LlmCall { model } => {
212                if !self.is_model_allowed(model) {
213                    return PolicyDecision::Deny {
214                        reason: format!("Model '{}' is not allowed for this tenant", model),
215                    };
216                }
217                PolicyDecision::Allow
218            }
219            super::PolicyAction::InvokeTool { tool_name } => {
220                if !self.is_tool_allowed(tool_name) {
221                    return PolicyDecision::Deny {
222                        reason: format!("Tool '{}' is not allowed for this tenant", tool_name),
223                    };
224                }
225                PolicyDecision::Allow
226            }
227            _ => PolicyDecision::Allow,
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::super::PolicyAction;
    use super::*;
    use std::collections::HashMap;

    /// Wraps `action` in a context for tenant `tenant-1`, with no user and no metadata.
    fn context_for(action: PolicyAction) -> PolicyContext {
        PolicyContext {
            tenant_id: Some("tenant-1".to_string()),
            user_id: None,
            action,
            metadata: HashMap::new(),
        }
    }

    // ---- TenantLimits ------------------------------------------------------

    #[test]
    fn test_tenant_limits_default() {
        let l = TenantLimits::default();
        assert_eq!(
            (
                l.max_concurrent_executions,
                l.max_executions_per_day,
                l.max_tokens_per_day,
                l.max_storage_bytes
            ),
            (Some(10), Some(1000), Some(1_000_000), Some(1024 * 1024 * 1024))
        );
    }

    #[test]
    fn test_tenant_limits_unlimited() {
        let l = TenantLimits::unlimited();
        let caps = [
            l.max_concurrent_executions,
            l.max_executions_per_day,
            l.max_tokens_per_day,
            l.max_storage_bytes,
        ];
        assert!(caps.iter().all(Option::is_none));
    }

    #[test]
    fn test_tenant_limits_free_tier() {
        let l = TenantLimits::free_tier();
        assert_eq!(
            (l.max_concurrent_executions, l.max_executions_per_day, l.max_tokens_per_day),
            (Some(2), Some(100), Some(50_000))
        );
    }

    #[test]
    fn test_tenant_limits_pro_tier() {
        let l = TenantLimits::pro_tier();
        assert_eq!(
            (l.max_concurrent_executions, l.max_executions_per_day, l.max_tokens_per_day),
            (Some(20), Some(10_000), Some(10_000_000))
        );
    }

    // ---- FeatureFlags ------------------------------------------------------

    #[test]
    fn test_feature_flags_default() {
        let flags = FeatureFlags::default();
        assert!(flags.enabled.is_empty() && flags.disabled.is_empty());
    }

    #[test]
    fn test_feature_flags_enable() {
        let flags = FeatureFlags::new().enable("feature_a").enable("feature_b");
        for name in ["feature_a", "feature_b"].iter() {
            assert!(flags.is_enabled(name));
        }
        assert!(!flags.is_enabled("feature_c"));
    }

    #[test]
    fn test_feature_flags_disable_overrides_enable() {
        let flags = FeatureFlags::new().enable("feature_x").disable("feature_x");
        // `disabled` wins even though the feature was enabled first.
        assert!(!flags.is_enabled("feature_x"));
    }

    #[test]
    fn test_feature_flags_with_defaults() {
        let flags = FeatureFlags::with_defaults();
        for name in ["basic_execution", "tool_invocation", "streaming"].iter() {
            assert!(flags.is_enabled(name));
        }
    }

    #[test]
    fn test_feature_flags_all_enabled() {
        let flags = FeatureFlags::all_enabled();
        let expected = [
            "basic_execution",
            "parallel_execution",
            "mcp_integration",
            "advanced_memory",
        ];
        assert!(expected.iter().all(|name| flags.is_enabled(name)));
    }

    // ---- TenantIsolation ---------------------------------------------------

    #[test]
    fn test_tenant_isolation_default() {
        assert_eq!(TenantIsolation::default(), TenantIsolation::Shared);
    }

    // ---- TenantPolicy ------------------------------------------------------

    #[test]
    fn test_tenant_policy_default() {
        let policy = TenantPolicy::default();
        assert!(policy.allowed_models.is_empty());
        assert!(policy.allowed_tools.is_none());
        assert_eq!(policy.isolation, TenantIsolation::Shared);
    }

    #[test]
    fn test_tenant_policy_with_limits() {
        let limits = TenantPolicy::new().with_limits(TenantLimits::free_tier()).limits;
        assert_eq!(limits.max_concurrent_executions, Some(2));
    }

    #[test]
    fn test_tenant_policy_with_features() {
        let features = TenantPolicy::new()
            .with_features(FeatureFlags::all_enabled())
            .features;
        assert!(features.is_enabled("parallel_execution"));
    }

    #[test]
    fn test_tenant_policy_allow_model() {
        let policy = TenantPolicy::new()
            .allow_model("gpt-4")
            .allow_model("claude-3");

        for model in ["gpt-4", "claude-3"].iter() {
            assert!(policy.is_model_allowed(model));
        }
        assert!(!policy.is_model_allowed("unknown-model"));
    }

    #[test]
    fn test_tenant_policy_allow_model_empty_allows_all() {
        // No model allow-list means every model is accepted.
        assert!(TenantPolicy::new().is_model_allowed("any-model"));
    }

    #[test]
    fn test_tenant_policy_allow_tool() {
        let policy = TenantPolicy::new()
            .allow_tool("web_search")
            .allow_tool("calculator");

        for tool in ["web_search", "calculator"].iter() {
            assert!(policy.is_tool_allowed(tool));
        }
        assert!(!policy.is_tool_allowed("file_system"));
    }

    #[test]
    fn test_tenant_policy_no_tool_restriction_allows_all() {
        // `allowed_tools == None` means every tool is accepted.
        assert!(TenantPolicy::new().is_tool_allowed("any-tool"));
    }

    #[test]
    fn test_tenant_policy_with_isolation() {
        let isolation = TenantPolicy::new()
            .with_isolation(TenantIsolation::Strict)
            .isolation;
        assert_eq!(isolation, TenantIsolation::Strict);
    }

    // ---- PolicyEvaluator ---------------------------------------------------

    #[test]
    fn test_tenant_policy_evaluate_llm_allowed() {
        // Empty model allow-list: every model passes.
        let context = context_for(PolicyAction::LlmCall {
            model: "gpt-4".to_string(),
        });
        assert!(TenantPolicy::new().evaluate(&context).is_allowed());
    }

    #[test]
    fn test_tenant_policy_evaluate_llm_denied() {
        // Only claude-3 is allow-listed, so gpt-4 must be rejected.
        let policy = TenantPolicy::new().allow_model("claude-3");
        let context = context_for(PolicyAction::LlmCall {
            model: "gpt-4".to_string(),
        });
        assert!(policy.evaluate(&context).is_denied());
    }

    #[test]
    fn test_tenant_policy_evaluate_tool_allowed() {
        // No tool allow-list: every tool passes.
        let context = context_for(PolicyAction::InvokeTool {
            tool_name: "any_tool".to_string(),
        });
        assert!(TenantPolicy::new().evaluate(&context).is_allowed());
    }

    #[test]
    fn test_tenant_policy_evaluate_tool_denied() {
        // Only calculator is allow-listed, so web_search must be rejected.
        let policy = TenantPolicy::new().allow_tool("calculator");
        let context = context_for(PolicyAction::InvokeTool {
            tool_name: "web_search".to_string(),
        });
        assert!(policy.evaluate(&context).is_denied());
    }
}